@@ -193,10 +193,20 @@ def __init__(
193
193
model_save_period : Optional [int ] = 5 ,
194
194
model_history_dir : Optional [str ] = "model_history" ,
195
195
) -> None :
196
+
197
+ # As trial parameters these get written to history array
198
+ # Ax trial_index and arm toegther locate a point
199
+ # Multiple points (Optimas trials) can share the same Ax trial_index
200
+ # vocs interface note: These are not part of vocs. They are only stored
201
+ # to allow keeping track of them from previous runs.
202
+ custom_trial_parameters = [
203
+ TrialParameter ("arm_name" , "ax_arm_name" , dtype = "U32" ),
204
+ TrialParameter ("ax_trial_id" , "ax_trial_index" , dtype = int ),
205
+ ]
196
206
self ._check_inputs (vocs , lofi_task , hifi_task )
197
207
198
208
# Convert discrete variables to trial parameters before calling super().__init__
199
- custom_trial_parameters = (
209
+ custom_trial_parameters . extend (
200
210
self ._convert_discrete_variables_to_trial_parameters (vocs )
201
211
)
202
212
@@ -224,10 +234,6 @@ def __init__(
224
234
self .gr_lofi = None
225
235
self ._experiment = self ._create_experiment ()
226
236
227
- # Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228
- self ._id_mapping = {}
229
- self ._next_id = 0
230
-
231
237
def _convert_discrete_variables_to_trial_parameters (
232
238
self , vocs : VOCS
233
239
) -> List [TrialParameter ]:
@@ -302,15 +308,8 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
302
308
if trial_param .name == "trial_type" :
303
309
point [trial_param .name ] = trial_type
304
310
305
- # Generate unique _id and store mapping
306
- current_id = self ._next_id
307
- self ._id_mapping [current_id ] = {
308
- "arm_name" : arm .name ,
309
- "ax_trial_id" : trial_index ,
310
- "trial_type" : trial_type ,
311
- }
312
- point ["_id" ] = current_id
313
- self ._next_id += 1
311
+ point ["ax_trial_id" ] = trial_index
312
+ point ["arm_name" ] = arm .name
314
313
points .append (point )
315
314
return points
316
315
@@ -328,14 +327,6 @@ def ingest(self, results: List[dict]) -> None:
328
327
)
329
328
trials .append (trial )
330
329
331
- # Apply _id mapping to all trials before processing
332
- for trial in trials :
333
- if trial .gen_id is not None and trial .gen_id in self ._id_mapping :
334
- mapping = self ._id_mapping [trial .gen_id ]
335
- trial .arm_name = mapping ["arm_name" ]
336
- trial .ax_trial_id = mapping ["ax_trial_id" ]
337
- # trial_type should already be in trial from custom_parameters
338
-
339
330
if self .gen_state == NOT_STARTED :
340
331
self ._incorporate_external_data (trials )
341
332
else :
0 commit comments