Skip to content

Commit 4be14a7

Browse files
committed
Change gen strategy from GPEI to BoTorch
1 parent 6766fa3 commit 4be14a7

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

tests/test_ax_generators.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def test_ax_single_fidelity_resume():
488488
# Check that the sobol step has been skipped.
489489
df = ax_client.get_trials_data_frame()
490490
assert len(df) == 12
491-
assert df["generation_method"].to_numpy()[-1] == "GPEI"
491+
assert df["generation_method"].to_numpy()[-1] == "BoTorch"
492492

493493
check_run_ax_service(
494494
ax_client, gen, exploration, n_failed_expected=2
@@ -821,7 +821,8 @@ def test_ax_service_init():
821821
assert df["generation_method"][j] == "Manual"
822822
for k in range(i, n_init - 1):
823823
assert df["generation_method"][k] == "Sobol"
824-
df["generation_method"][min(i, n_init)] == "GPEI"
824+
825+
df["generation_method"][min(i, n_init)] == "BoTorch"
825826

826827
# Try to load saved client from json. This used to fail when the SOBOL
827828
# step was skipped due to n_external > n_init. It is added here to prevent
@@ -870,21 +871,30 @@ def test_ax_service_init():
870871
assert df["generation_method"][j] == "Manual"
871872
for k in range(n_external, n_external + n_init):
872873
assert df["generation_method"][k] == "Sobol"
873-
df["generation_method"][n_external + n_init] == "GPEI"
874+
df["generation_method"][n_external + n_init] == "BoTorch"
874875

875876

876877
if __name__ == "__main__":
877-
test_ax_single_fidelity()
878-
test_ax_single_fidelity_resume()
879-
test_ax_single_fidelity_int()
880-
test_ax_single_fidelity_moo()
881-
test_ax_single_fidelity_fb()
882-
test_ax_single_fidelity_moo_fb()
883-
test_ax_single_fidelity_updated_params()
884-
test_ax_multi_fidelity()
885-
test_ax_multitask()
886-
test_ax_client()
887-
test_ax_single_fidelity_with_history()
888-
test_ax_multi_fidelity_with_history()
889-
test_ax_multitask_with_history()
890-
test_ax_service_init()
878+
tests = [
879+
test_ax_single_fidelity,
880+
test_ax_single_fidelity_resume,
881+
test_ax_single_fidelity_int,
882+
test_ax_single_fidelity_moo,
883+
test_ax_single_fidelity_fb,
884+
test_ax_single_fidelity_moo_fb,
885+
test_ax_single_fidelity_updated_params,
886+
test_ax_multi_fidelity,
887+
test_ax_multitask,
888+
test_ax_client,
889+
test_ax_single_fidelity_with_history,
890+
test_ax_multi_fidelity_with_history,
891+
test_ax_multitask_with_history,
892+
test_ax_service_init,
893+
]
894+
895+
for test in tests:
896+
print()
897+
print("-" * len(test.__name__))
898+
print(test.__name__) # Print test name
899+
print("-" * len(test.__name__))
900+
test() # Run the test

0 commit comments

Comments
 (0)