Skip to content

Commit af2b0d4

Browse files
committed
Refactor Ax service generator
* Simplify ingest function * self._ax_client.attach_trial not called on ignored trial
1 parent 8d9bc44 commit af2b0d4

File tree

1 file changed

+87
-77
lines changed
  • optimas/generators/ax/service

1 file changed

+87
-77
lines changed

optimas/generators/ax/service/base.py

Lines changed: 87 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Trial,
2424
VaryingParameter,
2525
Parameter,
26-
TrialStatus,
2726
)
2827
from optimas.generators.ax.base import AxGenerator
2928
from optimas.utils.ax import AxModelManager
@@ -157,6 +156,31 @@ def suggest(self, num_points: Optional[int]) -> List[dict]:
157156
points.append(point)
158157
return points
159158

159+
def ingest(self, results: List[dict]) -> None:
160+
"""Send the results of evaluations to the generator."""
161+
for result in results:
162+
trial = Trial.from_dict(
163+
trial_dict=result,
164+
varying_parameters=self._varying_parameters,
165+
objectives=self._objectives,
166+
analyzed_parameters=self._analyzed_parameters,
167+
custom_parameters=self._custom_trial_parameters,
168+
)
169+
if trial.ignored:
170+
continue
171+
try:
172+
ax_trial = self._ax_client.get_trial(trial.ax_trial_id)
173+
except AttributeError:
174+
ax_trial = self._insert_unknown_trial(trial)
175+
finally:
176+
if trial.completed:
177+
self._complete_trial(ax_trial.index, trial)
178+
elif trial.failed:
179+
if self._abandon_failed_trials:
180+
ax_trial.mark_abandoned()
181+
else:
182+
ax_trial.mark_failed()
183+
160184
def _ignore_out_of_bounds(self, trial: Trial) -> None:
161185
"""Check if trial parameters are within their bounds."""
162186
for var, value in zip(trial.varying_parameters, trial.parameter_values):
@@ -177,83 +201,69 @@ def ignore_trials(self, trials: List[Trial]) -> None:
177201
# Handle unknown trial
178202
self._ignore_out_of_bounds(trial)
179203

180-
def ingest(self, results: List[dict]) -> None:
181-
"""Send the results of evaluations to the generator."""
182-
for result in results:
183-
trial = Trial.from_dict(
184-
trial_dict=result,
185-
varying_parameters=self._varying_parameters,
186-
objectives=self._objectives,
187-
analyzed_parameters=self._analyzed_parameters,
188-
custom_parameters=self._custom_trial_parameters,
189-
)
190-
try:
191-
trial_id = trial.ax_trial_id
192-
ax_trial = self._ax_client.get_trial(trial_id)
193-
except AttributeError:
194-
params = {}
195-
for var, value in zip(
196-
trial.varying_parameters, trial.parameter_values
197-
):
198-
params[var.name] = value
199-
try:
200-
_, trial_id = self._ax_client.attach_trial(params)
201-
except ValueError as error:
202-
# Bypass checks from AxClient and manually add a trial
203-
# outside of the search space.
204-
# https://github.com/facebook/Ax/issues/768#issuecomment-1036515242
205-
if "not a valid value" in str(error):
206-
if self._fit_out_of_design:
207-
ax_trial = self._ax_client.experiment.new_trial()
208-
ax_trial.add_arm(Arm(parameters=params))
209-
ax_trial.mark_running(no_runner_required=True)
210-
trial_id = ax_trial.index
211-
else:
212-
raise error
213-
ax_trial = self._ax_client.get_trial(trial_id)
204+
def _get_ingest_params(self, trial: Trial) -> Dict:
205+
"""Return a trials ingest parameters as a dictionary."""
206+
params = {}
207+
for var, value in zip(trial.varying_parameters, trial.parameter_values):
208+
params[var.name] = value
209+
return params
214210

215-
# Since data was given externally, reduce number of
216-
# initialization trials, but only if they have not failed.
217-
if trial.completed and not self._enforce_n_init:
218-
generation_strategy = self._ax_client.generation_strategy
219-
current_step = generation_strategy.current_step
220-
# Reduce only if there are still Sobol trials left.
221-
if current_step.model == Models.SOBOL:
222-
for tc in current_step.transition_criteria:
223-
# Looping over all criterial makes sure we reduce
224-
# the transition thresholds due to `_n_init`
225-
# (i.e., max trials) and `min_trials_observed=1` (
226-
# i.e., min trials).
227-
if isinstance(tc, (MinTrials, MaxTrials)):
228-
tc.threshold -= 1
229-
generation_strategy._maybe_transition_to_next_node()
230-
finally:
231-
if trial.ignored:
232-
continue
233-
elif trial.completed:
234-
outcome_evals = {}
235-
# Add objective evaluations.
236-
for ev in trial.objective_evaluations:
237-
outcome_evals[ev.parameter.name] = (ev.value, ev.sem)
238-
# Add outcome constraints evaluations.
239-
ax_config = self._ax_client.experiment.optimization_config
240-
if ax_config.outcome_constraints:
241-
ocs = [
242-
oc.metric.name
243-
for oc in ax_config.outcome_constraints
244-
]
245-
for ev in trial.parameter_evaluations:
246-
par_name = ev.parameter.name
247-
if par_name in ocs:
248-
outcome_evals[par_name] = (ev.value, ev.sem)
249-
self._ax_client.complete_trial(
250-
trial_index=trial_id, raw_data=outcome_evals
251-
)
252-
elif trial.failed:
253-
if self._abandon_failed_trials:
254-
ax_trial.mark_abandoned()
255-
else:
256-
ax_trial.mark_failed()
211+
def _insert_unknown_trial(self, trial: Trial) -> None:
212+
"""Insert an unknown trial into the Ax client."""
213+
params = self._get_ingest_params(trial)
214+
try:
215+
_, trial_id = self._ax_client.attach_trial(params)
216+
except ValueError as error:
217+
# Bypass checks from AxClient and manually add a trial
218+
# outside of the search space.
219+
# https://github.com/facebook/Ax/issues/768#issuecomment-1036515242
220+
if "not a valid value" in str(error):
221+
if self._fit_out_of_design:
222+
ax_trial = self._ax_client.experiment.new_trial()
223+
ax_trial.add_arm(Arm(parameters=params))
224+
ax_trial.mark_running(no_runner_required=True)
225+
trial_id = ax_trial.index
226+
else:
227+
raise error
228+
ax_trial = self._ax_client.get_trial(trial_id)
229+
230+
# Since data was given externally, reduce number of
231+
# initialization trials, but only if they have not failed.
232+
if trial.completed and not self._enforce_n_init:
233+
generation_strategy = self._ax_client.generation_strategy
234+
current_step = generation_strategy.current_step
235+
# Reduce only if there are still Sobol trials left.
236+
if current_step.model == Models.SOBOL:
237+
for tc in current_step.transition_criteria:
238+
# Looping over all criterial makes sure we reduce
239+
# the transition thresholds due to `_n_init`
240+
# (i.e., max trials) and `min_trials_observed=1` (
241+
# i.e., min trials).
242+
if isinstance(tc, (MinTrials, MaxTrials)):
243+
tc.threshold -= 1
244+
generation_strategy._maybe_transition_to_next_node()
245+
return ax_trial
246+
247+
def _complete_trial(self, ax_trial_index: int, trial: Trial) -> None:
248+
"""Complete a trial."""
249+
outcome_evals = {}
250+
# Add objective evaluations.
251+
for ev in trial.objective_evaluations:
252+
outcome_evals[ev.parameter.name] = (ev.value, ev.sem)
253+
# Add outcome constraints evaluations.
254+
ax_config = self._ax_client.experiment.optimization_config
255+
if ax_config.outcome_constraints:
256+
ocs = [
257+
oc.metric.name
258+
for oc in ax_config.outcome_constraints
259+
]
260+
for ev in trial.parameter_evaluations:
261+
par_name = ev.parameter.name
262+
if par_name in ocs:
263+
outcome_evals[par_name] = (ev.value, ev.sem)
264+
self._ax_client.complete_trial(
265+
trial_index=ax_trial_index, raw_data=outcome_evals
266+
)
257267

258268
def _create_ax_client(self) -> AxClient:
259269
"""Create Ax client."""

0 commit comments

Comments
 (0)