Skip to content

Commit bbc23bc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9df7f76 commit bbc23bc

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,12 @@ def __init__(
194194
model_history_dir: Optional[str] = "model_history",
195195
) -> None:
196196
self._check_inputs(vocs, lofi_task, hifi_task)
197-
197+
198198
# Convert discrete variables to trial parameters before calling super().__init__
199-
custom_trial_parameters = self._convert_discrete_variables_to_trial_parameters(vocs)
200-
199+
custom_trial_parameters = (
200+
self._convert_discrete_variables_to_trial_parameters(vocs)
201+
)
202+
201203
super().__init__(
202204
vocs=vocs,
203205
use_cuda=use_cuda,
@@ -221,19 +223,23 @@ def __init__(
221223
self.current_trial = None
222224
self.gr_lofi = None
223225
self._experiment = self._create_experiment()
224-
226+
225227
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
226228
self._id_mapping = {}
227229
self._next_id = 0
228230

229-
def _convert_discrete_variables_to_trial_parameters(self, vocs: VOCS) -> List[TrialParameter]:
231+
def _convert_discrete_variables_to_trial_parameters(
232+
self, vocs: VOCS
233+
) -> List[TrialParameter]:
230234
"""Convert discrete variables from VOCS to TrialParameter objects."""
231235
trial_parameters = []
232236
for var_name, var_spec in vocs.variables.items():
233237
if isinstance(var_spec, DiscreteVariable):
234238
# Convert discrete variable to trial parameter
235239
max_len = max(len(str(val)) for val in var_spec.values)
236-
trial_param = TrialParameter(var_name, var_name, dtype=f"U{max_len}")
240+
trial_param = TrialParameter(
241+
var_name, var_name, dtype=f"U{max_len}"
242+
)
237243
trial_parameters.append(trial_param)
238244
return trial_parameters
239245

@@ -258,12 +264,12 @@ def _validate_vocs(self, vocs: VOCS) -> None:
258264
"Objectives given: {}.".format(n_objectives)
259265
)
260266
# Check that there is a discrete variable called 'trial_type'
261-
assert "trial_type" in vocs.variables, (
262-
"Multitask generator requires a discrete variable named 'trial_type'"
263-
)
264-
assert isinstance(vocs.variables["trial_type"], DiscreteVariable), (
265-
"Variable 'trial_type' must be a discrete variable"
266-
)
267+
assert (
268+
"trial_type" in vocs.variables
269+
), "Multitask generator requires a discrete variable named 'trial_type'"
270+
assert isinstance(
271+
vocs.variables["trial_type"], DiscreteVariable
272+
), "Variable 'trial_type' must be a discrete variable"
267273

268274
def _check_inputs(
269275
self,
@@ -295,7 +301,7 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
295301
for trial_param in self._custom_trial_parameters:
296302
if trial_param.name == "trial_type":
297303
point[trial_param.name] = trial_type
298-
304+
299305
# Generate unique _id and store mapping
300306
current_id = self._next_id
301307
self._id_mapping[current_id] = {

tests/test_ax_generators.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,11 @@ def test_ax_multitask():
534534
"""Test that an exploration with a multitask generator runs"""
535535

536536
vocs = VOCS(
537-
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0], "trial_type": {"task_1", "task_2"}},
537+
variables={
538+
"x0": [-50.0, 5.0],
539+
"x1": [-5.0, 15.0],
540+
"trial_type": {"task_1", "task_2"},
541+
},
538542
objectives={"f": "MAXIMIZE"},
539543
)
540544

@@ -721,7 +725,11 @@ def test_ax_multitask_with_history():
721725
"""
722726

723727
vocs = VOCS(
724-
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0], "trial_type": {"task_1", "task_2"}},
728+
variables={
729+
"x0": [-50.0, 5.0],
730+
"x1": [-5.0, 15.0],
731+
"trial_type": {"task_1", "task_2"},
732+
},
725733
objectives={"f": "MAXIMIZE"},
726734
)
727735

0 commit comments

Comments
 (0)