@@ -105,6 +105,12 @@ def __init__(
105
105
self ._custom_trial_parameters = (
106
106
[] if custom_trial_parameters is None else custom_trial_parameters
107
107
)
108
+
109
+ # Automatically add discrete variables as trial parameters
110
+ discrete_trial_params = (
111
+ self ._convert_vocs_discrete_variables_to_trial_parameters ()
112
+ )
113
+ self ._custom_trial_parameters .extend (discrete_trial_params )
108
114
self ._allow_fixed_parameters = allow_fixed_parameters
109
115
self ._allow_updating_parameters = allow_updating_parameters
110
116
self ._gen_function = persistent_generator
@@ -191,6 +197,31 @@ def _convert_vocs_observables_to_parameters(self) -> List[Parameter]:
191
197
parameters .append (param )
192
198
return parameters
193
199
200
+ def _convert_vocs_discrete_variables_to_trial_parameters (
201
+ self ,
202
+ ) -> List [TrialParameter ]:
203
+ """Convert discrete variables from VOCS to TrialParameter objects.
204
+
205
+ Only converts discrete variables that were NOT already converted to
206
+ VaryingParameters.
207
+ """
208
+ trial_parameters = []
209
+ # Get the names of variables that were already converted to
210
+ # VaryingParameters
211
+ varying_param_names = {vp .name for vp in self ._varying_parameters }
212
+
213
+ for var_name , var_spec in self ._vocs .variables .items ():
214
+ if isinstance (var_spec , DiscreteVariable ):
215
+ # Only convert if it wasn't already converted to a
216
+ # VaryingParameter
217
+ if var_name not in varying_param_names :
218
+ max_len = max (len (str (val )) for val in var_spec .values )
219
+ trial_param = TrialParameter (
220
+ var_name , var_name , dtype = f"U{ max_len } "
221
+ )
222
+ trial_parameters .append (trial_param )
223
+ return trial_parameters
224
+
194
225
@property
195
226
def vocs (self ) -> VOCS :
196
227
"""Get the VOCS object."""
0 commit comments