@@ -194,10 +194,12 @@ def __init__(
194
194
model_history_dir : Optional [str ] = "model_history" ,
195
195
) -> None :
196
196
self ._check_inputs (vocs , lofi_task , hifi_task )
197
-
197
+
198
198
# Convert discrete variables to trial parameters before calling super().__init__
199
- custom_trial_parameters = self ._convert_discrete_variables_to_trial_parameters (vocs )
200
-
199
+ custom_trial_parameters = (
200
+ self ._convert_discrete_variables_to_trial_parameters (vocs )
201
+ )
202
+
201
203
super ().__init__ (
202
204
vocs = vocs ,
203
205
use_cuda = use_cuda ,
@@ -221,19 +223,23 @@ def __init__(
221
223
self .current_trial = None
222
224
self .gr_lofi = None
223
225
self ._experiment = self ._create_experiment ()
224
-
226
+
225
227
# Internal mapping: _id -> (arm_name, ax_trial_id, trial_type)
226
228
self ._id_mapping = {}
227
229
self ._next_id = 0
228
230
229
- def _convert_discrete_variables_to_trial_parameters (self , vocs : VOCS ) -> List [TrialParameter ]:
231
+ def _convert_discrete_variables_to_trial_parameters (
232
+ self , vocs : VOCS
233
+ ) -> List [TrialParameter ]:
230
234
"""Convert discrete variables from VOCS to TrialParameter objects."""
231
235
trial_parameters = []
232
236
for var_name , var_spec in vocs .variables .items ():
233
237
if isinstance (var_spec , DiscreteVariable ):
234
238
# Convert discrete variable to trial parameter
235
239
max_len = max (len (str (val )) for val in var_spec .values )
236
- trial_param = TrialParameter (var_name , var_name , dtype = f"U{ max_len } " )
240
+ trial_param = TrialParameter (
241
+ var_name , var_name , dtype = f"U{ max_len } "
242
+ )
237
243
trial_parameters .append (trial_param )
238
244
return trial_parameters
239
245
@@ -258,12 +264,12 @@ def _validate_vocs(self, vocs: VOCS) -> None:
258
264
"Objectives given: {}." .format (n_objectives )
259
265
)
260
266
# Check that there is a discrete variable called 'trial_type'
261
- assert "trial_type" in vocs . variables , (
262
- "Multitask generator requires a discrete variable named ' trial_type'"
263
- )
264
- assert isinstance (vocs . variables [ "trial_type" ], DiscreteVariable ), (
265
- "Variable ' trial_type' must be a discrete variable"
266
- )
267
+ assert (
268
+ "trial_type" in vocs . variables
269
+ ), "Multitask generator requires a discrete variable named 'trial_type'"
270
+ assert isinstance (
271
+ vocs . variables [ " trial_type" ], DiscreteVariable
272
+ ), "Variable 'trial_type' must be a discrete variable"
267
273
268
274
def _check_inputs (
269
275
self ,
@@ -295,7 +301,7 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
295
301
for trial_param in self ._custom_trial_parameters :
296
302
if trial_param .name == "trial_type" :
297
303
point [trial_param .name ] = trial_type
298
-
304
+
299
305
# Generate unique _id and store mapping
300
306
current_id = self ._next_id
301
307
self ._id_mapping [current_id ] = {
0 commit comments