Skip to content

Commit ce09d6b

Browse files
committed
Update multitask for vocs and apply id mapping
1 parent 9b60771 commit ce09d6b

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
TrialStatus,
6868
)
6969
from .ax_metric import AxMetric
70+
from generator_standard.vocs import VOCS
7071

7172
# Define generator states.
7273
NOT_STARTED = "not_started"
@@ -152,10 +153,8 @@ class AxMultitaskGenerator(AxGenerator):
152153
153154
Parameters
154155
----------
155-
varying_parameters : list of VaryingParameter
156-
List of input parameters to vary. One them should be a fidelity.
157-
objectives : list of Objective
158-
List of optimization objectives. Only one objective is supported.
156+
vocs : VOCS
157+
VOCS object defining variables, objectives, constraints, and observables.
159158
lofi_task, hifi_task : Task
160159
The low- and high-fidelity tasks.
161160
analyzed_parameters : list of Parameter, optional
@@ -184,11 +183,9 @@ class AxMultitaskGenerator(AxGenerator):
184183

185184
def __init__(
186185
self,
187-
varying_parameters: List[VaryingParameter],
188-
objectives: List[Objective],
186+
vocs: VOCS,
189187
lofi_task: Task,
190188
hifi_task: Task,
191-
analyzed_parameters: Optional[List[Parameter]] = None,
192189
use_cuda: Optional[bool] = False,
193190
gpu_id: Optional[int] = 0,
194191
dedicated_resources: Optional[bool] = False,
@@ -200,15 +197,11 @@ def __init__(
200197
# Ax trial_index and arm toegther locate a point
201198
# Multiple points (Optimas trials) can share the same Ax trial_index
202199
custom_trial_parameters = [
203-
TrialParameter("arm_name", "ax_arm_name", dtype="U32"),
204200
TrialParameter("trial_type", "ax_trial_type", dtype="U32"),
205-
TrialParameter("ax_trial_id", "ax_trial_index", dtype=int),
206201
]
207-
self._check_inputs(varying_parameters, objectives, lofi_task, hifi_task)
202+
self._check_inputs(vocs, lofi_task, hifi_task)
208203
super().__init__(
209-
varying_parameters=varying_parameters,
210-
objectives=objectives,
211-
analyzed_parameters=analyzed_parameters,
204+
vocs=vocs,
212205
use_cuda=use_cuda,
213206
gpu_id=gpu_id,
214207
dedicated_resources=dedicated_resources,
@@ -230,6 +223,10 @@ def __init__(
230223
self.current_trial = None
231224
self.gr_lofi = None
232225
self._experiment = self._create_experiment()
226+
227+
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228+
self._id_mapping = {}
229+
self._next_id = 0
233230

234231
def get_gen_specs(
235232
self, sim_workers: int, run_params: Dict, sim_max: int
@@ -244,14 +241,13 @@ def get_gen_specs(
244241

245242
def _check_inputs(
246243
self,
247-
varying_parameters: List[VaryingParameter],
248-
objectives: List[Objective],
244+
vocs: VOCS,
249245
lofi_task: Task,
250246
hifi_task: Task,
251247
) -> None:
252248
"""Check that the generator inputs are valid."""
253249
# Check that only one objective has been given.
254-
n_objectives = len(objectives)
250+
n_objectives = len(vocs.objectives)
255251
assert n_objectives == 1, (
256252
"Multitask generator supports only a single objective. "
257253
"Objectives given: {}.".format(n_objectives)
@@ -274,11 +270,16 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
274270
var.name: arm.parameters.get(var.name)
275271
for var in self._varying_parameters
276272
}
277-
# SH for VOCS standard these will need to be 'variables'
278-
# For now much match the trial parameter names.
279-
point["ax_trial_id"] = trial_index
280-
point["arm_name"] = arm.name
281-
point["trial_type"] = trial_type
273+
# Generate unique _id and store mapping
274+
current_id = self._next_id
275+
self._id_mapping[current_id] = {
276+
"arm_name": arm.name,
277+
"ax_trial_id": trial_index,
278+
"trial_type": trial_type
279+
}
280+
point["_id"] = current_id
281+
point["trial_type"] = trial_type # Keep trial_type for now
282+
self._next_id += 1
282283
points.append(point)
283284
return points
284285

@@ -295,6 +296,15 @@ def ingest(self, results: List[dict]) -> None:
295296
custom_parameters=self._custom_trial_parameters,
296297
)
297298
trials.append(trial)
299+
300+
# Apply _id mapping to all trials before processing
301+
for trial in trials:
302+
if trial.gen_id is not None and trial.gen_id in self._id_mapping:
303+
mapping = self._id_mapping[trial.gen_id]
304+
trial.arm_name = mapping["arm_name"]
305+
trial.ax_trial_id = mapping["ax_trial_id"]
306+
# trial_type should already be in trial from custom_parameters
307+
298308
if self.gen_state == NOT_STARTED:
299309
self._incorporate_external_data(trials)
300310
else:

0 commit comments

Comments
 (0)