Skip to content

Commit f883c7b

Browse files
committed
Revert multitask to preserve ax id/arm info.
* This way they are stored and read in from previous history * They are not part of vocs * Using TrialParameter uses already setup format in Optimas * They get put in the libE history array * trial_type is a vocs discrete, but uses TrialParameter internally
1 parent 3a6ef44 commit f883c7b

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

optimas/generators/ax/developer/multitask.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,20 @@ def __init__(
193193
model_save_period: Optional[int] = 5,
194194
model_history_dir: Optional[str] = "model_history",
195195
) -> None:
196+
197+
# As trial parameters these get written to history array
198+
# Ax trial_index and arm toegther locate a point
199+
# Multiple points (Optimas trials) can share the same Ax trial_index
200+
# vocs interface note: These are not part of vocs. They are only stored
201+
# to allow keeping track of them from previous runs.
202+
custom_trial_parameters = [
203+
TrialParameter("arm_name", "ax_arm_name", dtype="U32"),
204+
TrialParameter("ax_trial_id", "ax_trial_index", dtype=int),
205+
]
196206
self._check_inputs(vocs, lofi_task, hifi_task)
197207

198208
# Convert discrete variables to trial parameters before calling super().__init__
199-
custom_trial_parameters = (
209+
custom_trial_parameters.extend(
200210
self._convert_discrete_variables_to_trial_parameters(vocs)
201211
)
202212

@@ -224,10 +234,6 @@ def __init__(
224234
self.gr_lofi = None
225235
self._experiment = self._create_experiment()
226236

227-
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228-
self._id_mapping = {}
229-
self._next_id = 0
230-
231237
def _convert_discrete_variables_to_trial_parameters(
232238
self, vocs: VOCS
233239
) -> List[TrialParameter]:
@@ -302,15 +308,8 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
302308
if trial_param.name == "trial_type":
303309
point[trial_param.name] = trial_type
304310

305-
# Generate unique _id and store mapping
306-
current_id = self._next_id
307-
self._id_mapping[current_id] = {
308-
"arm_name": arm.name,
309-
"ax_trial_id": trial_index,
310-
"trial_type": trial_type,
311-
}
312-
point["_id"] = current_id
313-
self._next_id += 1
311+
point["ax_trial_id"] = trial_index
312+
point["arm_name"] = arm.name
314313
points.append(point)
315314
return points
316315

@@ -328,14 +327,6 @@ def ingest(self, results: List[dict]) -> None:
328327
)
329328
trials.append(trial)
330329

331-
# Apply _id mapping to all trials before processing
332-
for trial in trials:
333-
if trial.gen_id is not None and trial.gen_id in self._id_mapping:
334-
mapping = self._id_mapping[trial.gen_id]
335-
trial.arm_name = mapping["arm_name"]
336-
trial.ax_trial_id = mapping["ax_trial_id"]
337-
# trial_type should already be in trial from custom_parameters
338-
339330
if self.gen_state == NOT_STARTED:
340331
self._incorporate_external_data(trials)
341332
else:

0 commit comments

Comments
 (0)