Skip to content

Commit e516585

Browse files
committed
Update multitask test for standard
1 parent 01f712b commit e516585

File tree

1 file changed

+40
-16
lines changed

1 file changed

+40
-16
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,18 @@ def __init__(
196196
model_save_period: Optional[int] = 5,
197197
model_history_dir: Optional[str] = "model_history",
198198
) -> None:
199+
# SH note for standardization
200+
# As trial parameters these get written to history array
201+
# Ax trial_index and arm toegther locate a point
202+
# Multiple points (Optimas trials) can share the same Ax trial_index
203+
# Standard (VOCS) does not have equiv. of trial parameters -> would be variables.
204+
# If want to use _id can have a mapping inside the generator, but those
205+
# points that are part of same ax trial_index will not be seen in history.
199206
custom_trial_parameters = [
200207
TrialParameter("arm_name", "ax_arm_name", dtype="U32"),
201208
TrialParameter("trial_type", "ax_trial_type", dtype="U32"),
202-
TrialParameter("trial_index", "ax_trial_index", dtype=int),
209+
# SH changed from trial_index to ax_trial_id - trial_index is usually Optimas index?
210+
TrialParameter("ax_trial_id", "ax_trial_index", dtype=int),
203211
]
204212
self._check_inputs(varying_parameters, objectives, lofi_task, hifi_task)
205213
super().__init__(
@@ -259,24 +267,39 @@ def _check_inputs(
259267
"The number of low-fidelity trials must be larger than or equal "
260268
"to the number of high-fidelity trials"
261269
)
262-
263-
def suggest(self, trials: List[Trial]) -> List[Trial]:
264-
"""Fill in the parameter values of the requested trials."""
265-
for trial in trials:
270+
271+
def suggest(self, num_points: Optional[int]) -> List[dict]:
272+
"""Request the next set of points to evaluate."""
273+
points = []
274+
for _ in range(num_points):
266275
next_trial = self._get_next_trial_arm()
267276
if next_trial is not None:
268-
arm, trial_type, trial_index = next_trial
269-
trial.parameter_values = [
270-
arm.parameters.get(var.name)
271-
for var in self._varying_parameters
272-
]
273-
trial.trial_type = trial_type
274-
trial.arm_name = arm.name
275-
trial.trial_index = trial_index
276-
return trials
277-
278-
def ingest(self, trials: List[Trial]) -> None:
277+
arm, trial_type, trial_index = next_trial
278+
point = {
279+
var.name: arm.parameters.get(var.name)
280+
for var in self._varying_parameters
281+
}
282+
# SH for VOCS standard these will need to be declared as variables
283+
# For now much match the trial parameter names.
284+
point["ax_trial_id"] = trial_index
285+
point["arm_name"] = arm.name
286+
point["trial_type"] = trial_type
287+
points.append(point)
288+
return points
289+
290+
def ingest(self, results: List[dict]) -> None:
279291
"""Incorporate evaluated trials into experiment."""
292+
# reconstruct Optimastrials
293+
trials = []
294+
for result in results:
295+
trial = Trial.from_dict(
296+
trial_dict=result,
297+
varying_parameters=self._varying_parameters,
298+
objectives=self._objectives,
299+
analyzed_parameters=self._analyzed_parameters,
300+
custom_parameters=self._custom_trial_parameters,
301+
)
302+
trials.append(trial)
280303
if self.gen_state == NOT_STARTED:
281304
self._incorporate_external_data(trials)
282305
else:
@@ -285,6 +308,7 @@ def ingest(self, trials: List[Trial]) -> None:
285308
def _incorporate_external_data(self, trials: List[Trial]) -> None:
286309
"""Incorporate external data (e.g., from history) into experiment."""
287310
# Get trial indices.
311+
# SH should have handling if trial_indexs are None...
288312
trial_indices = []
289313
for trial in trials:
290314
trial_indices.append(trial.trial_index)

0 commit comments

Comments
 (0)