Skip to content

Commit 9df7f76

Browse files
committed
Specify multitask trial_type as discrete variable
1 parent 07eee24 commit 9df7f76

File tree

3 files changed

+51
-25
lines changed

3 files changed

+51
-25
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
TrialStatus,
6868
)
6969
from .ax_metric import AxMetric
70-
from generator_standard.vocs import VOCS
70+
from generator_standard.vocs import VOCS, DiscreteVariable
7171

7272
# Define generator states.
7373
NOT_STARTED = "not_started"
@@ -193,13 +193,11 @@ def __init__(
193193
model_save_period: Optional[int] = 5,
194194
model_history_dir: Optional[str] = "model_history",
195195
) -> None:
196-
# As trial parameters these get written to history array
197-
# Ax trial_index and arm toegther locate a point
198-
# Multiple points (Optimas trials) can share the same Ax trial_index
199-
custom_trial_parameters = [
200-
TrialParameter("trial_type", "ax_trial_type", dtype="U32"),
201-
]
202196
self._check_inputs(vocs, lofi_task, hifi_task)
197+
198+
# Convert discrete variables to trial parameters before calling super().__init__
199+
custom_trial_parameters = self._convert_discrete_variables_to_trial_parameters(vocs)
200+
203201
super().__init__(
204202
vocs=vocs,
205203
use_cuda=use_cuda,
@@ -223,11 +221,22 @@ def __init__(
223221
self.current_trial = None
224222
self.gr_lofi = None
225223
self._experiment = self._create_experiment()
226-
224+
227225
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228226
self._id_mapping = {}
229227
self._next_id = 0
230228

229+
def _convert_discrete_variables_to_trial_parameters(self, vocs: VOCS) -> List[TrialParameter]:
230+
"""Convert discrete variables from VOCS to TrialParameter objects."""
231+
trial_parameters = []
232+
for var_name, var_spec in vocs.variables.items():
233+
if isinstance(var_spec, DiscreteVariable):
234+
# Convert discrete variable to trial parameter
235+
max_len = max(len(str(val)) for val in var_spec.values)
236+
trial_param = TrialParameter(var_name, var_name, dtype=f"U{max_len}")
237+
trial_parameters.append(trial_param)
238+
return trial_parameters
239+
231240
def get_gen_specs(
232241
self, sim_workers: int, run_params: Dict, sim_max: int
233242
) -> Dict:
@@ -239,19 +248,30 @@ def get_gen_specs(
239248
gen_specs["out"].append(("task", str, max_length))
240249
return gen_specs
241250

251+
def _validate_vocs(self, vocs: VOCS) -> None:
252+
"""Validate VOCS for multitask generator."""
253+
super()._validate_vocs(vocs)
254+
# Check that only one objective has been given.
255+
n_objectives = len(vocs.objectives)
256+
assert n_objectives == 1, (
257+
"Multitask generator supports only a single objective. "
258+
"Objectives given: {}.".format(n_objectives)
259+
)
260+
# Check that there is a discrete variable called 'trial_type'
261+
assert "trial_type" in vocs.variables, (
262+
"Multitask generator requires a discrete variable named 'trial_type'"
263+
)
264+
assert isinstance(vocs.variables["trial_type"], DiscreteVariable), (
265+
"Variable 'trial_type' must be a discrete variable"
266+
)
267+
242268
def _check_inputs(
243269
self,
244270
vocs: VOCS,
245271
lofi_task: Task,
246272
hifi_task: Task,
247273
) -> None:
248274
"""Check that the generator inputs are valid."""
249-
# Check that only one objective has been given.
250-
n_objectives = len(vocs.objectives)
251-
assert n_objectives == 1, (
252-
"Multitask generator supports only a single objective. "
253-
"Objectives given: {}.".format(n_objectives)
254-
)
255275
# Check that the number of low-fidelity trials per iteration is larger
256276
# than that of high-fidelity trials.
257277
assert lofi_task.n_opt >= hifi_task.n_opt, (
@@ -270,6 +290,12 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
270290
var.name: arm.parameters.get(var.name)
271291
for var in self._varying_parameters
272292
}
293+
# SH We can use a discrete var here in vocs (converted for now to trial parameters)
294+
# But unlike varying parameters the name refers to a fixed generator concept.
295+
for trial_param in self._custom_trial_parameters:
296+
if trial_param.name == "trial_type":
297+
point[trial_param.name] = trial_type
298+
273299
# Generate unique _id and store mapping
274300
current_id = self._next_id
275301
self._id_mapping[current_id] = {
@@ -278,7 +304,6 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
278304
"trial_type": trial_type,
279305
}
280306
point["_id"] = current_id
281-
point["trial_type"] = trial_type # Keep trial_type for now
282307
self._next_id += 1
283308
points.append(point)
284309
return points

optimas/generators/base.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
TrialParameter,
2121
TrialStatus,
2222
)
23-
from generator_standard.vocs import VOCS
23+
from generator_standard.vocs import VOCS, ContinuousVariable
2424
from generator_standard.generator import Generator as StandardGenerator
2525

2626
logger = get_logger(__name__)
@@ -127,13 +127,14 @@ def _convert_vocs_variables_to_varying_parameters(
127127
varying_parameters = []
128128
for var_name, var_spec in self._vocs.variables.items():
129129
# Only handle ContinuousVariable for now
130-
vp = VaryingParameter(
131-
name=var_name,
132-
lower_bound=var_spec.domain[0],
133-
upper_bound=var_spec.domain[1],
134-
default_value=var_spec.default_value,
135-
)
136-
varying_parameters.append(vp)
130+
if isinstance(var_spec, ContinuousVariable):
131+
vp = VaryingParameter(
132+
name=var_name,
133+
lower_bound=var_spec.domain[0],
134+
upper_bound=var_spec.domain[1],
135+
default_value=var_spec.default_value,
136+
)
137+
varying_parameters.append(vp)
137138
return varying_parameters
138139

139140
def _convert_vocs_objectives_to_objectives(self) -> List[Objective]:

tests/test_ax_generators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def test_ax_multitask():
534534
"""Test that an exploration with a multitask generator runs"""
535535

536536
vocs = VOCS(
537-
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0]},
537+
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0], "trial_type": {"task_1", "task_2"}},
538538
objectives={"f": "MAXIMIZE"},
539539
)
540540

@@ -721,7 +721,7 @@ def test_ax_multitask_with_history():
721721
"""
722722

723723
vocs = VOCS(
724-
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0]},
724+
variables={"x0": [-50.0, 5.0], "x1": [-5.0, 15.0], "trial_type": {"task_1", "task_2"}},
725725
objectives={"f": "MAXIMIZE"},
726726
)
727727

0 commit comments

Comments
 (0)