Skip to content

Commit 62350f1

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Add AutoTransitionAfterGenCriterion to storage (facebook#2614)
Summary: Pull Request resolved: facebook#2614 During development, it became clear that I never added AutoTransitionAfterGenCriterion to the encoder/decoder registry. This is a simple diff to add that Reviewed By: saitcakmak Differential Revision: D60419201 fbshipit-source-id: 6cd6bc70633ba85206ac5ba0b084d1a716fa446f
1 parent 5ae965f commit 62350f1

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

ax/storage/json_store/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
from ax.modelbridge.registry import ModelRegistryBase
9393
from ax.modelbridge.transforms.base import Transform
9494
from ax.modelbridge.transition_criterion import (
95+
AutoTransitionAfterGenCriterion,
9596
MaxGenerationParallelism,
9697
MaxTrials,
9798
MinimumPreferenceOccurances,
@@ -183,6 +184,7 @@
183184
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
184185
AugmentedBraninMetric: metric_to_dict,
185186
AugmentedHartmann6Metric: metric_to_dict,
187+
AutoTransitionAfterGenCriterion: transition_criterion_to_dict,
186188
BatchTrial: batch_to_dict,
187189
BenchmarkMetric: metric_to_dict,
188190
BoTorchModel: botorch_model_to_dict,
@@ -290,6 +292,7 @@
290292
"AndEarlyStoppingStrategy": AndEarlyStoppingStrategy,
291293
"AugmentedBraninMetric": AugmentedBraninMetric,
292294
"AugmentedHartmann6Metric": AugmentedHartmann6Metric,
295+
"AutoTransitionAfterGenCriterion": AutoTransitionAfterGenCriterion,
293296
"Arm": Arm,
294297
"AggregatedBenchmarkResult": AggregatedBenchmarkResult,
295298
"BatchTrial": BatchTrial,

ax/storage/json_store/tests/test_json_store.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@
181181
"GenerationStrategy",
182182
partial(sobol_gpei_generation_node_gs, with_model_selection=True),
183183
),
184+
(
185+
"GenerationStrategy",
186+
partial(sobol_gpei_generation_node_gs, with_auto_transition=True),
187+
),
184188
("GeneratorRun", get_generator_run),
185189
("Hartmann6Metric", get_hartmann_metric),
186190
("HierarchicalSearchSpace", get_hierarchical_search_space),

ax/utils/testing/modeling_stubs.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ax.modelbridge.transforms.base import Transform
3232
from ax.modelbridge.transforms.int_to_float import IntToFloat
3333
from ax.modelbridge.transition_criterion import (
34+
AutoTransitionAfterGenCriterion,
3435
MaxGenerationParallelism,
3536
MaxTrials,
3637
MinimumPreferenceOccurances,
@@ -212,6 +213,7 @@ def get_generation_strategy(
212213

213214
def sobol_gpei_generation_node_gs(
214215
with_model_selection: bool = False,
216+
with_auto_transition: bool = False,
215217
) -> GenerationStrategy:
216218
"""Returns a basic SOBOL+MBM GS using GenerationNodes for testing.
217219
@@ -255,6 +257,7 @@ def sobol_gpei_generation_node_gs(
255257
not_in_statuses=None,
256258
),
257259
]
260+
alt_mbm_criterion = [AutoTransitionAfterGenCriterion(transition_to="MBM_node")]
258261
step_model_kwargs = {"silently_filter_kwargs": True}
259262
sobol_model_spec = ModelSpec(
260263
model_enum=Models.SOBOL,
@@ -284,12 +287,20 @@ def sobol_gpei_generation_node_gs(
284287
else:
285288
best_model_selector = None
286289

287-
mbm_node = GenerationNode(
288-
node_name="MBM_node",
289-
transition_criteria=mbm_criterion,
290-
model_specs=mbm_model_specs,
291-
best_model_selector=best_model_selector,
292-
)
290+
if with_auto_transition:
291+
mbm_node = GenerationNode(
292+
node_name="MBM_node",
293+
transition_criteria=alt_mbm_criterion,
294+
model_specs=mbm_model_specs,
295+
best_model_selector=best_model_selector,
296+
)
297+
else:
298+
mbm_node = GenerationNode(
299+
node_name="MBM_node",
300+
transition_criteria=mbm_criterion,
301+
model_specs=mbm_model_specs,
302+
best_model_selector=best_model_selector,
303+
)
293304

294305
sobol_mbm_GS_nodes = GenerationStrategy(
295306
name="Sobol+MBM_Nodes",

0 commit comments

Comments
 (0)