Skip to content

Commit ff92723

Browse files
committed
Handle discrete variables in gen base class
1 parent b91a7db commit ff92723

File tree

2 files changed

+31
-20
lines changed

2 files changed

+31
-20
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,6 @@ def __init__(
205205
]
206206
self._check_inputs(vocs, lofi_task, hifi_task)
207207

208-
# Convert discrete variables to trial parameters before calling super().__init__
209-
custom_trial_parameters.extend(
210-
self._convert_discrete_variables_to_trial_parameters(vocs)
211-
)
212-
213208
super().__init__(
214209
vocs=vocs,
215210
use_cuda=use_cuda,
@@ -234,21 +229,6 @@ def __init__(
234229
self.gr_lofi = None
235230
self._experiment = self._create_experiment()
236231

237-
def _convert_discrete_variables_to_trial_parameters(
238-
self, vocs: VOCS
239-
) -> List[TrialParameter]:
240-
"""Convert discrete variables from VOCS to TrialParameter objects."""
241-
trial_parameters = []
242-
for var_name, var_spec in vocs.variables.items():
243-
if isinstance(var_spec, DiscreteVariable):
244-
# Convert discrete variable to trial parameter
245-
max_len = max(len(str(val)) for val in var_spec.values)
246-
trial_param = TrialParameter(
247-
var_name, var_name, dtype=f"U{max_len}"
248-
)
249-
trial_parameters.append(trial_param)
250-
return trial_parameters
251-
252232
def get_gen_specs(
253233
self, sim_workers: int, run_params: Dict, sim_max: int
254234
) -> Dict:

optimas/generators/base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ def __init__(
105105
self._custom_trial_parameters = (
106106
[] if custom_trial_parameters is None else custom_trial_parameters
107107
)
108+
109+
# Automatically add discrete variables as trial parameters
110+
discrete_trial_params = (
111+
self._convert_vocs_discrete_variables_to_trial_parameters()
112+
)
113+
self._custom_trial_parameters.extend(discrete_trial_params)
108114
self._allow_fixed_parameters = allow_fixed_parameters
109115
self._allow_updating_parameters = allow_updating_parameters
110116
self._gen_function = persistent_generator
@@ -191,6 +197,31 @@ def _convert_vocs_observables_to_parameters(self) -> List[Parameter]:
191197
parameters.append(param)
192198
return parameters
193199

200+
def _convert_vocs_discrete_variables_to_trial_parameters(
201+
self,
202+
) -> List[TrialParameter]:
203+
"""Convert discrete variables from VOCS to TrialParameter objects.
204+
205+
Only converts discrete variables that were NOT already converted to
206+
VaryingParameters.
207+
"""
208+
trial_parameters = []
209+
# Get the names of variables that were already converted to
210+
# VaryingParameters
211+
varying_param_names = {vp.name for vp in self._varying_parameters}
212+
213+
for var_name, var_spec in self._vocs.variables.items():
214+
if isinstance(var_spec, DiscreteVariable):
215+
# Only convert if it wasn't already converted to a
216+
# VaryingParameter
217+
if var_name not in varying_param_names:
218+
max_len = max(len(str(val)) for val in var_spec.values)
219+
trial_param = TrialParameter(
220+
var_name, var_name, dtype=f"U{max_len}"
221+
)
222+
trial_parameters.append(trial_param)
223+
return trial_parameters
224+
194225
@property
195226
def vocs(self) -> VOCS:
196227
"""Get the VOCS object."""

0 commit comments

Comments
 (0)