67
67
TrialStatus ,
68
68
)
69
69
from .ax_metric import AxMetric
70
- from generator_standard .vocs import VOCS
70
+ from generator_standard .vocs import VOCS , DiscreteVariable
71
71
72
72
# Define generator states.
73
73
NOT_STARTED = "not_started"
@@ -193,13 +193,11 @@ def __init__(
193
193
model_save_period : Optional [int ] = 5 ,
194
194
model_history_dir : Optional [str ] = "model_history" ,
195
195
) -> None :
196
- # As trial parameters these get written to history array
197
- # Ax trial_index and arm toegther locate a point
198
- # Multiple points (Optimas trials) can share the same Ax trial_index
199
- custom_trial_parameters = [
200
- TrialParameter ("trial_type" , "ax_trial_type" , dtype = "U32" ),
201
- ]
202
196
self ._check_inputs (vocs , lofi_task , hifi_task )
197
+
198
+ # Convert discrete variables to trial parameters before calling super().__init__
199
+ custom_trial_parameters = self ._convert_discrete_variables_to_trial_parameters (vocs )
200
+
203
201
super ().__init__ (
204
202
vocs = vocs ,
205
203
use_cuda = use_cuda ,
@@ -223,11 +221,22 @@ def __init__(
223
221
self .current_trial = None
224
222
self .gr_lofi = None
225
223
self ._experiment = self ._create_experiment ()
226
-
224
+
227
225
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
228
226
self ._id_mapping = {}
229
227
self ._next_id = 0
230
228
229
+ def _convert_discrete_variables_to_trial_parameters (self , vocs : VOCS ) -> List [TrialParameter ]:
230
+ """Convert discrete variables from VOCS to TrialParameter objects."""
231
+ trial_parameters = []
232
+ for var_name , var_spec in vocs .variables .items ():
233
+ if isinstance (var_spec , DiscreteVariable ):
234
+ # Convert discrete variable to trial parameter
235
+ max_len = max (len (str (val )) for val in var_spec .values )
236
+ trial_param = TrialParameter (var_name , var_name , dtype = f"U{ max_len } " )
237
+ trial_parameters .append (trial_param )
238
+ return trial_parameters
239
+
231
240
def get_gen_specs (
232
241
self , sim_workers : int , run_params : Dict , sim_max : int
233
242
) -> Dict :
@@ -239,19 +248,30 @@ def get_gen_specs(
239
248
gen_specs ["out" ].append (("task" , str , max_length ))
240
249
return gen_specs
241
250
251
+ def _validate_vocs (self , vocs : VOCS ) -> None :
252
+ """Validate VOCS for multitask generator."""
253
+ super ()._validate_vocs (vocs )
254
+ # Check that only one objective has been given.
255
+ n_objectives = len (vocs .objectives )
256
+ assert n_objectives == 1 , (
257
+ "Multitask generator supports only a single objective. "
258
+ "Objectives given: {}." .format (n_objectives )
259
+ )
260
+ # Check that there is a discrete variable called 'trial_type'
261
+ assert "trial_type" in vocs .variables , (
262
+ "Multitask generator requires a discrete variable named 'trial_type'"
263
+ )
264
+ assert isinstance (vocs .variables ["trial_type" ], DiscreteVariable ), (
265
+ "Variable 'trial_type' must be a discrete variable"
266
+ )
267
+
242
268
def _check_inputs (
243
269
self ,
244
270
vocs : VOCS ,
245
271
lofi_task : Task ,
246
272
hifi_task : Task ,
247
273
) -> None :
248
274
"""Check that the generator inputs are valid."""
249
- # Check that only one objective has been given.
250
- n_objectives = len (vocs .objectives )
251
- assert n_objectives == 1 , (
252
- "Multitask generator supports only a single objective. "
253
- "Objectives given: {}." .format (n_objectives )
254
- )
255
275
# Check that the number of low-fidelity trials per iteration is larger
256
276
# than that of high-fidelity trials.
257
277
assert lofi_task .n_opt >= hifi_task .n_opt , (
@@ -270,6 +290,12 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
270
290
var .name : arm .parameters .get (var .name )
271
291
for var in self ._varying_parameters
272
292
}
293
+ # SH We can use a discrete var here in vocs (converted for now to trial parameters)
294
+ # But unlike varying parameters the name refers to a fixed generator concept.
295
+ for trial_param in self ._custom_trial_parameters :
296
+ if trial_param .name == "trial_type" :
297
+ point [trial_param .name ] = trial_type
298
+
273
299
# Generate unique _id and store mapping
274
300
current_id = self ._next_id
275
301
self ._id_mapping [current_id ] = {
@@ -278,7 +304,6 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
278
304
"trial_type" : trial_type ,
279
305
}
280
306
point ["_id" ] = current_id
281
- point ["trial_type" ] = trial_type # Keep trial_type for now
282
307
self ._next_id += 1
283
308
points .append (point )
284
309
return points
0 commit comments