Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ax/generation_strategy/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from ax.adapter.registry import GeneratorRegistryBase, Generators
from ax.core.experiment import Experiment
from ax.core.experiment_status import ExperimentStatus
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
Expand Down Expand Up @@ -66,6 +67,7 @@ def _make_sobol_step(
generator_kwargs={"deduplicate": True, "seed": seed},
use_all_trials_in_exp=True,
should_deduplicate=should_deduplicate,
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
)


Expand Down Expand Up @@ -133,6 +135,7 @@ def _make_botorch_step(
max_parallelism=max_concurrency,
generator_kwargs=generator_kwargs,
should_deduplicate=should_deduplicate,
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
)


Expand Down
30 changes: 30 additions & 0 deletions ax/generation_strategy/tests/test_dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from ax.adapter.registry import Generators
from ax.adapter.transforms.log_y import LogY
from ax.core.experiment_status import ExperimentStatus
from ax.core.objective import Objective
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.generation_strategy.dispatch_utils import (
Expand Down Expand Up @@ -1088,3 +1089,32 @@ def test_non_saasbo_discards_irrelevant_generator_kwargs(self) -> None:
run_branin_experiment_with_generation_strategy(
generation_strategy=gp_saasbo,
)

def test_suggested_experiment_status(self) -> None:
with self.subTest(
"choose_generation_strategy_legacy sets statuses on both nodes"
):
gs = choose_generation_strategy_legacy(
search_space=get_branin_search_space(),
)
# Sobol step should be INITIALIZATION
self.assertEqual(
gs._nodes[0].suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)
# BoTorch step should be OPTIMIZATION
self.assertEqual(
gs._nodes[1].suggested_experiment_status,
ExperimentStatus.OPTIMIZATION,
)

with self.subTest("force_random_search sets INITIALIZATION on Sobol-only GS"):
gs = choose_generation_strategy_legacy(
search_space=get_branin_search_space(),
force_random_search=True,
)
self.assertEqual(len(gs._nodes), 1)
self.assertEqual(
gs._nodes[0].suggested_experiment_status,
ExperimentStatus.INITIALIZATION,
)
4 changes: 4 additions & 0 deletions ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ class TestAxOrchestrator(TestCase):
"transition_criteria="
"[MinTrials(transition_to='GenerationStep_1_BoTorch'), "
"MinTrials(transition_to='GenerationStep_1_BoTorch')], "
"suggested_experiment_status=ExperimentStatus.INITIALIZATION, "
"pausing_criteria="
"[MaxTrialsAwaitingData(threshold=5)]), "
"GenerationNode(name='GenerationStep_1_BoTorch', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=None, "
"suggested_experiment_status=ExperimentStatus.OPTIMIZATION, "
"pausing_criteria="
"[MaxGenerationParallelism(threshold=3)])]), "
"options=OrchestratorOptions(max_pending_trials=10, "
Expand Down Expand Up @@ -2913,12 +2915,14 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator):
"transition_criteria="
"[MinTrials(transition_to='GenerationStep_1_BoTorch'), "
"MinTrials(transition_to='GenerationStep_1_BoTorch')], "
"suggested_experiment_status=ExperimentStatus.INITIALIZATION, "
"pausing_criteria="
"[MaxTrialsAwaitingData(threshold=5)]), "
"GenerationNode(name='GenerationStep_1_BoTorch', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=None, "
"suggested_experiment_status=ExperimentStatus.OPTIMIZATION, "
"pausing_criteria="
"[MaxGenerationParallelism(threshold=3)])]), "
"options=OrchestratorOptions(max_pending_trials=10, "
Expand Down