Skip to content

Commit bfc4fad

Browse files
committed
Move test code to match run order
1 parent b7ae657 commit bfc4fad

File tree

1 file changed

+69
-69
lines changed

1 file changed

+69
-69
lines changed

tests/test_ax_generators.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,75 @@ def test_ax_single_fidelity():
199199
make_plots(gen)
200200

201201

202+
def test_ax_single_fidelity_resume():
203+
"""
204+
Test that an exploration with an AxService generator can resume
205+
with an updated range of the varying parameters, even if some
206+
old trials are out of the updated range.
207+
"""
208+
global trial_count
209+
global trials_to_fail
210+
trial_count = 0
211+
trials_to_fail = []
212+
213+
fit_out_of_design_vals = [False, True]
214+
215+
for fit_out_of_design in fit_out_of_design_vals:
216+
var1 = VaryingParameter("x0", 5.1, 6.0)
217+
var2 = VaryingParameter("x1", -5.0, 15.0)
218+
obj = Objective("f", minimize=False)
219+
p1 = Parameter("p1")
220+
221+
gen = AxSingleFidelityGenerator(
222+
varying_parameters=[var1, var2],
223+
objectives=[obj],
224+
analyzed_parameters=[p1],
225+
parameter_constraints=["x0 + x1 <= 10"],
226+
outcome_constraints=["p1 <= 30"],
227+
fit_out_of_design=fit_out_of_design,
228+
)
229+
ev = FunctionEvaluator(function=eval_func_sf)
230+
exploration = Exploration(
231+
generator=gen,
232+
evaluator=ev,
233+
max_evals=20,
234+
sim_workers=2,
235+
exploration_dir_path="./tests_output/test_ax_single_fidelity",
236+
libe_comms="local_threading",
237+
resume=True,
238+
)
239+
240+
# Get reference to original AxClient.
241+
ax_client = gen._ax_client
242+
243+
# Run exploration.
244+
exploration.run(n_evals=1)
245+
246+
if not fit_out_of_design:
247+
# Check that no old evaluations were added
248+
assert (
249+
len(exploration.history) == 11
250+
), f"Got: {len(exploration.history)}"
251+
assert all(exploration.history.trial_ignored.to_numpy()[:-1])
252+
# Check that the sobol step has not been skipped.
253+
df = ax_client.get_trials_data_frame()
254+
assert len(df) == 1
255+
assert df["generation_method"].to_numpy()[0] == "Sobol"
256+
257+
else:
258+
# Check that the old evaluations were added
259+
assert len(exploration.history) == 12
260+
assert not all(exploration.history.trial_ignored.to_numpy())
261+
# Check that the sobol step has been skipped.
262+
df = ax_client.get_trials_data_frame()
263+
assert len(df) == 12
264+
assert df["generation_method"].to_numpy()[-1] == "BoTorch"
265+
266+
check_run_ax_service(
267+
ax_client, gen, exploration, n_failed_expected=2
268+
)
269+
270+
202271
def test_ax_single_fidelity_int():
203272
"""
204273
Test that an exploration with a single-fidelity generator runs
@@ -428,75 +497,6 @@ def test_ax_single_fidelity_updated_params():
428497
make_plots(gen)
429498

430499

431-
def test_ax_single_fidelity_resume():
432-
"""
433-
Test that an exploration with an AxService generator can resume
434-
with an updated range of the varying parameters, even if some
435-
old trials are out of the updated range.
436-
"""
437-
global trial_count
438-
global trials_to_fail
439-
trial_count = 0
440-
trials_to_fail = []
441-
442-
fit_out_of_design_vals = [False, True]
443-
444-
for fit_out_of_design in fit_out_of_design_vals:
445-
var1 = VaryingParameter("x0", 5.1, 6.0)
446-
var2 = VaryingParameter("x1", -5.0, 15.0)
447-
obj = Objective("f", minimize=False)
448-
p1 = Parameter("p1")
449-
450-
gen = AxSingleFidelityGenerator(
451-
varying_parameters=[var1, var2],
452-
objectives=[obj],
453-
analyzed_parameters=[p1],
454-
parameter_constraints=["x0 + x1 <= 10"],
455-
outcome_constraints=["p1 <= 30"],
456-
fit_out_of_design=fit_out_of_design,
457-
)
458-
ev = FunctionEvaluator(function=eval_func_sf)
459-
exploration = Exploration(
460-
generator=gen,
461-
evaluator=ev,
462-
max_evals=20,
463-
sim_workers=2,
464-
exploration_dir_path="./tests_output/test_ax_single_fidelity",
465-
libe_comms="local_threading",
466-
resume=True,
467-
)
468-
469-
# Get reference to original AxClient.
470-
ax_client = gen._ax_client
471-
472-
# Run exploration.
473-
exploration.run(n_evals=1)
474-
475-
if not fit_out_of_design:
476-
# Check that no old evaluations were added
477-
assert (
478-
len(exploration.history) == 11
479-
), f"Got: {len(exploration.history)}"
480-
assert all(exploration.history.trial_ignored.to_numpy()[:-1])
481-
# Check that the sobol step has not been skipped.
482-
df = ax_client.get_trials_data_frame()
483-
assert len(df) == 1
484-
assert df["generation_method"].to_numpy()[0] == "Sobol"
485-
486-
else:
487-
# Check that the old evaluations were added
488-
assert len(exploration.history) == 12
489-
assert not all(exploration.history.trial_ignored.to_numpy())
490-
# Check that the sobol step has been skipped.
491-
df = ax_client.get_trials_data_frame()
492-
assert len(df) == 12
493-
assert df["generation_method"].to_numpy()[-1] == "BoTorch"
494-
495-
check_run_ax_service(
496-
ax_client, gen, exploration, n_failed_expected=2
497-
)
498-
499-
500500
def test_ax_multi_fidelity():
501501
"""Test that an exploration with a multifidelity generator runs"""
502502

0 commit comments

Comments
 (0)