67
67
TrialStatus ,
68
68
)
69
69
from .ax_metric import AxMetric
70
+ from generator_standard .vocs import VOCS
70
71
71
72
# Define generator states.
72
73
NOT_STARTED = "not_started"
@@ -152,10 +153,8 @@ class AxMultitaskGenerator(AxGenerator):
152
153
153
154
Parameters
154
155
----------
155
- varying_parameters : list of VaryingParameter
156
- List of input parameters to vary. One them should be a fidelity.
157
- objectives : list of Objective
158
- List of optimization objectives. Only one objective is supported.
156
+ vocs : VOCS
157
+ VOCS object defining variables, objectives, constraints, and observables.
159
158
lofi_task, hifi_task : Task
160
159
The low- and high-fidelity tasks.
161
160
analyzed_parameters : list of Parameter, optional
@@ -184,11 +183,9 @@ class AxMultitaskGenerator(AxGenerator):
184
183
185
184
def __init__ (
186
185
self ,
187
- varying_parameters : List [VaryingParameter ],
188
- objectives : List [Objective ],
186
+ vocs : VOCS ,
189
187
lofi_task : Task ,
190
188
hifi_task : Task ,
191
- analyzed_parameters : Optional [List [Parameter ]] = None ,
192
189
use_cuda : Optional [bool ] = False ,
193
190
gpu_id : Optional [int ] = 0 ,
194
191
dedicated_resources : Optional [bool ] = False ,
@@ -200,15 +197,11 @@ def __init__(
200
197
# Ax trial_index and arm toegther locate a point
201
198
# Multiple points (Optimas trials) can share the same Ax trial_index
202
199
custom_trial_parameters = [
203
- TrialParameter ("arm_name" , "ax_arm_name" , dtype = "U32" ),
204
200
TrialParameter ("trial_type" , "ax_trial_type" , dtype = "U32" ),
205
- TrialParameter ("ax_trial_id" , "ax_trial_index" , dtype = int ),
206
201
]
207
- self ._check_inputs (varying_parameters , objectives , lofi_task , hifi_task )
202
+ self ._check_inputs (vocs , lofi_task , hifi_task )
208
203
super ().__init__ (
209
- varying_parameters = varying_parameters ,
210
- objectives = objectives ,
211
- analyzed_parameters = analyzed_parameters ,
204
+ vocs = vocs ,
212
205
use_cuda = use_cuda ,
213
206
gpu_id = gpu_id ,
214
207
dedicated_resources = dedicated_resources ,
@@ -230,6 +223,10 @@ def __init__(
230
223
self .current_trial = None
231
224
self .gr_lofi = None
232
225
self ._experiment = self ._create_experiment ()
226
+
227
+ # Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228
+ self ._id_mapping = {}
229
+ self ._next_id = 0
233
230
234
231
def get_gen_specs (
235
232
self , sim_workers : int , run_params : Dict , sim_max : int
@@ -244,14 +241,13 @@ def get_gen_specs(
244
241
245
242
def _check_inputs (
246
243
self ,
247
- varying_parameters : List [VaryingParameter ],
248
- objectives : List [Objective ],
244
+ vocs : VOCS ,
249
245
lofi_task : Task ,
250
246
hifi_task : Task ,
251
247
) -> None :
252
248
"""Check that the generator inputs are valid."""
253
249
# Check that only one objective has been given.
254
- n_objectives = len (objectives )
250
+ n_objectives = len (vocs . objectives )
255
251
assert n_objectives == 1 , (
256
252
"Multitask generator supports only a single objective. "
257
253
"Objectives given: {}." .format (n_objectives )
@@ -274,11 +270,16 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
274
270
var .name : arm .parameters .get (var .name )
275
271
for var in self ._varying_parameters
276
272
}
277
- # SH for VOCS standard these will need to be 'variables'
278
- # For now much match the trial parameter names.
279
- point ["ax_trial_id" ] = trial_index
280
- point ["arm_name" ] = arm .name
281
- point ["trial_type" ] = trial_type
273
+ # Generate unique _id and store mapping
274
+ current_id = self ._next_id
275
+ self ._id_mapping [current_id ] = {
276
+ "arm_name" : arm .name ,
277
+ "ax_trial_id" : trial_index ,
278
+ "trial_type" : trial_type
279
+ }
280
+ point ["_id" ] = current_id
281
+ point ["trial_type" ] = trial_type # Keep trial_type for now
282
+ self ._next_id += 1
282
283
points .append (point )
283
284
return points
284
285
@@ -295,6 +296,15 @@ def ingest(self, results: List[dict]) -> None:
295
296
custom_parameters = self ._custom_trial_parameters ,
296
297
)
297
298
trials .append (trial )
299
+
300
+ # Apply _id mapping to all trials before processing
301
+ for trial in trials :
302
+ if trial .gen_id is not None and trial .gen_id in self ._id_mapping :
303
+ mapping = self ._id_mapping [trial .gen_id ]
304
+ trial .arm_name = mapping ["arm_name" ]
305
+ trial .ax_trial_id = mapping ["ax_trial_id" ]
306
+ # trial_type should already be in trial from custom_parameters
307
+
298
308
if self .gen_state == NOT_STARTED :
299
309
self ._incorporate_external_data (trials )
300
310
else :
0 commit comments