23
23
Trial ,
24
24
VaryingParameter ,
25
25
Parameter ,
26
- TrialStatus ,
27
26
)
28
27
from optimas .generators .ax .base import AxGenerator
29
28
from optimas .utils .ax import AxModelManager
@@ -157,6 +156,31 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
157
156
points .append (point )
158
157
return points
159
158
159
+ def ingest (self , results : List [dict ]) -> None :
160
+ """Send the results of evaluations to the generator."""
161
+ for result in results :
162
+ trial = Trial .from_dict (
163
+ trial_dict = result ,
164
+ varying_parameters = self ._varying_parameters ,
165
+ objectives = self ._objectives ,
166
+ analyzed_parameters = self ._analyzed_parameters ,
167
+ custom_parameters = self ._custom_trial_parameters ,
168
+ )
169
+ if trial .ignored :
170
+ continue
171
+ try :
172
+ ax_trial = self ._ax_client .get_trial (trial .ax_trial_id )
173
+ except AttributeError :
174
+ ax_trial = self ._insert_unknown_trial (trial )
175
+ finally :
176
+ if trial .completed :
177
+ self ._complete_trial (ax_trial .index , trial )
178
+ elif trial .failed :
179
+ if self ._abandon_failed_trials :
180
+ ax_trial .mark_abandoned ()
181
+ else :
182
+ ax_trial .mark_failed ()
183
+
160
184
def _ignore_out_of_bounds (self , trial : Trial ) -> None :
161
185
"""Check if trial parameters are within their bounds."""
162
186
for var , value in zip (trial .varying_parameters , trial .parameter_values ):
@@ -177,83 +201,69 @@ def ignore_trials(self, trials: List[Trial]) -> None:
177
201
# Handle unknown trial
178
202
self ._ignore_out_of_bounds (trial )
179
203
180
- def ingest (self , results : List [dict ]) -> None :
181
- """Send the results of evaluations to the generator."""
182
- for result in results :
183
- trial = Trial .from_dict (
184
- trial_dict = result ,
185
- varying_parameters = self ._varying_parameters ,
186
- objectives = self ._objectives ,
187
- analyzed_parameters = self ._analyzed_parameters ,
188
- custom_parameters = self ._custom_trial_parameters ,
189
- )
190
- try :
191
- trial_id = trial .ax_trial_id
192
- ax_trial = self ._ax_client .get_trial (trial_id )
193
- except AttributeError :
194
- params = {}
195
- for var , value in zip (
196
- trial .varying_parameters , trial .parameter_values
197
- ):
198
- params [var .name ] = value
199
- try :
200
- _ , trial_id = self ._ax_client .attach_trial (params )
201
- except ValueError as error :
202
- # Bypass checks from AxClient and manually add a trial
203
- # outside of the search space.
204
- # https://github.com/facebook/Ax/issues/768#issuecomment-1036515242
205
- if "not a valid value" in str (error ):
206
- if self ._fit_out_of_design :
207
- ax_trial = self ._ax_client .experiment .new_trial ()
208
- ax_trial .add_arm (Arm (parameters = params ))
209
- ax_trial .mark_running (no_runner_required = True )
210
- trial_id = ax_trial .index
211
- else :
212
- raise error
213
- ax_trial = self ._ax_client .get_trial (trial_id )
204
+ def _get_ingest_params (self , trial : Trial ) -> Dict :
205
+ """Return a trials ingest parameters as a dictionary."""
206
+ params = {}
207
+ for var , value in zip (trial .varying_parameters , trial .parameter_values ):
208
+ params [var .name ] = value
209
+ return params
214
210
215
- # Since data was given externally, reduce number of
216
- # initialization trials, but only if they have not failed.
217
- if trial .completed and not self ._enforce_n_init :
218
- generation_strategy = self ._ax_client .generation_strategy
219
- current_step = generation_strategy .current_step
220
- # Reduce only if there are still Sobol trials left.
221
- if current_step .model == Models .SOBOL :
222
- for tc in current_step .transition_criteria :
223
- # Looping over all criterial makes sure we reduce
224
- # the transition thresholds due to `_n_init`
225
- # (i.e., max trials) and `min_trials_observed=1` (
226
- # i.e., min trials).
227
- if isinstance (tc , (MinTrials , MaxTrials )):
228
- tc .threshold -= 1
229
- generation_strategy ._maybe_transition_to_next_node ()
230
- finally :
231
- if trial .ignored :
232
- continue
233
- elif trial .completed :
234
- outcome_evals = {}
235
- # Add objective evaluations.
236
- for ev in trial .objective_evaluations :
237
- outcome_evals [ev .parameter .name ] = (ev .value , ev .sem )
238
- # Add outcome constraints evaluations.
239
- ax_config = self ._ax_client .experiment .optimization_config
240
- if ax_config .outcome_constraints :
241
- ocs = [
242
- oc .metric .name
243
- for oc in ax_config .outcome_constraints
244
- ]
245
- for ev in trial .parameter_evaluations :
246
- par_name = ev .parameter .name
247
- if par_name in ocs :
248
- outcome_evals [par_name ] = (ev .value , ev .sem )
249
- self ._ax_client .complete_trial (
250
- trial_index = trial_id , raw_data = outcome_evals
251
- )
252
- elif trial .failed :
253
- if self ._abandon_failed_trials :
254
- ax_trial .mark_abandoned ()
255
- else :
256
- ax_trial .mark_failed ()
211
+ def _insert_unknown_trial (self , trial : Trial ) -> None :
212
+ """Insert an unknown trial into the Ax client."""
213
+ params = self ._get_ingest_params (trial )
214
+ try :
215
+ _ , trial_id = self ._ax_client .attach_trial (params )
216
+ except ValueError as error :
217
+ # Bypass checks from AxClient and manually add a trial
218
+ # outside of the search space.
219
+ # https://github.com/facebook/Ax/issues/768#issuecomment-1036515242
220
+ if "not a valid value" in str (error ):
221
+ if self ._fit_out_of_design :
222
+ ax_trial = self ._ax_client .experiment .new_trial ()
223
+ ax_trial .add_arm (Arm (parameters = params ))
224
+ ax_trial .mark_running (no_runner_required = True )
225
+ trial_id = ax_trial .index
226
+ else :
227
+ raise error
228
+ ax_trial = self ._ax_client .get_trial (trial_id )
229
+
230
+ # Since data was given externally, reduce number of
231
+ # initialization trials, but only if they have not failed.
232
+ if trial .completed and not self ._enforce_n_init :
233
+ generation_strategy = self ._ax_client .generation_strategy
234
+ current_step = generation_strategy .current_step
235
+ # Reduce only if there are still Sobol trials left.
236
+ if current_step .model == Models .SOBOL :
237
+ for tc in current_step .transition_criteria :
238
+ # Looping over all criterial makes sure we reduce
239
+ # the transition thresholds due to `_n_init`
240
+ # (i.e., max trials) and `min_trials_observed=1` (
241
+ # i.e., min trials).
242
+ if isinstance (tc , (MinTrials , MaxTrials )):
243
+ tc .threshold -= 1
244
+ generation_strategy ._maybe_transition_to_next_node ()
245
+ return ax_trial
246
+
247
+ def _complete_trial (self , ax_trial_index : int , trial : Trial ) -> None :
248
+ """Complete a trial."""
249
+ outcome_evals = {}
250
+ # Add objective evaluations.
251
+ for ev in trial .objective_evaluations :
252
+ outcome_evals [ev .parameter .name ] = (ev .value , ev .sem )
253
+ # Add outcome constraints evaluations.
254
+ ax_config = self ._ax_client .experiment .optimization_config
255
+ if ax_config .outcome_constraints :
256
+ ocs = [
257
+ oc .metric .name
258
+ for oc in ax_config .outcome_constraints
259
+ ]
260
+ for ev in trial .parameter_evaluations :
261
+ par_name = ev .parameter .name
262
+ if par_name in ocs :
263
+ outcome_evals [par_name ] = (ev .value , ev .sem )
264
+ self ._ax_client .complete_trial (
265
+ trial_index = ax_trial_index , raw_data = outcome_evals
266
+ )
257
267
258
268
def _create_ax_client (self ) -> AxClient :
259
269
"""Create Ax client."""
0 commit comments