Skip to content

Commit 40fd08e

Browse files
Lena Kashtelyanmeta-codesync[bot]
authored andcommitted
Move all step-based GS functionality to nodes (facebook#4746)
Summary: X-link: facebookincubator/MCGrad#94 Pull Request resolved: facebook#4746 Pull Request resolved: facebook#4731 A lot going on here: 1. Make `GenerationStep` a factory for `GenerationNode` by replacing its `__init__(self, ...)` constructor (which would have to return a `GStep`) with a `__new__(cls, ...)` construtor, which can return a `GNode` (magic, thanks Devmate!) 2. Adapt storage: stop storing steps and just treat them as nodes. No step-only fields will be saved going forward. Backward compatibility is handled though. 3. Change a bazillion tests and checks in downstream applications. NOTE: need to remember next steps here: https://www.internalfb.com/diff/D86066476?dst_version_fbid=732904879835890&transaction_fbid=1400151778204976, cc mgarrard sorrybigdiff Differential Revision: D80128678
1 parent 4a99121 commit 40fd08e

31 files changed

+1182
-940
lines changed

ax/benchmark/tests/methods/test_methods.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def _test_mbm_acquisition(self, batch_size: int) -> None:
3737
)
3838
self.assertEqual(method.name, expected_name)
3939
gs = method.generation_strategy
40-
sobol, kg = gs._steps
41-
self.assertEqual(kg.generator, Generators.BOTORCH_MODULAR)
42-
generator_kwargs = none_throws(kg.generator_kwargs)
40+
sobol, kg = gs._nodes
41+
self.assertEqual(kg.generator_spec.generator_enum, Generators.BOTORCH_MODULAR)
42+
generator_kwargs = none_throws(kg.generator_spec.generator_kwargs)
4343
self.assertEqual(generator_kwargs["botorch_acqf_class"], qKnowledgeGradient)
4444
surrogate_spec = generator_kwargs["surrogate_spec"]
4545
self.assertEqual(
@@ -64,11 +64,11 @@ def _test_benchmark_replication_runs(
6464
num_sobol_trials=2,
6565
name="test",
6666
)
67-
n_sobol_trials = method.generation_strategy._steps[0].num_trials
68-
self.assertEqual(n_sobol_trials, 2)
67+
n_sobol_trials_tc = method.generation_strategy._nodes[0].num_trials
68+
self.assertEqual(n_sobol_trials_tc, 2)
6969
self.assertEqual(method.name, "test")
7070
# Only run one non-Sobol trial
71-
n_total_trials = n_sobol_trials + 1
71+
n_total_trials = n_sobol_trials_tc + 1
7272
problem = get_benchmark_problem(
7373
problem_key="ackley4", num_trials=n_total_trials
7474
)
@@ -104,5 +104,5 @@ def test_sobol(self) -> None:
104104
method = get_sobol_benchmark_method()
105105
self.assertEqual(method.name, "Sobol")
106106
gs = method.generation_strategy
107-
self.assertEqual(len(gs._steps), 1)
108-
self.assertEqual(gs._steps[0].generator, Generators.SOBOL)
107+
self.assertEqual(len(gs._nodes), 1)
108+
self.assertEqual(gs._nodes[0].generator_spec.generator_enum, Generators.SOBOL)

ax/benchmark/tests/test_benchmark_method.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ def test_benchmark_method(self) -> None:
1818
self.assertEqual(method.name, "Sobol10")
1919

2020
# test that `fit_tracking_metrics` has been correctly set to False
21-
for step in method.generation_strategy._steps:
21+
for step in method.generation_strategy._nodes:
2222
self.assertFalse(
23-
none_throws(step.generator_kwargs).get("fit_tracking_metrics")
23+
none_throws(step.generator_specs[0].generator_kwargs).get(
24+
"fit_tracking_metrics"
25+
)
2426
)
2527

2628
method = BenchmarkMethod(generation_strategy=gs)

ax/core/base_trial.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,6 @@ def __init__(
128128
# Counter to maintain how many arms have been named by this BatchTrial
129129
self._num_arms_created = 0
130130

131-
# If generator run(s) in this trial were generated from a generation
132-
# strategy, this property will be set to the generation step that produced
133-
# the generator run(s).
134-
self._generation_step_index: int | None = None
135131
# NOTE: Please do not store any data related to trial deployment or data-
136132
# fetching in properties. It is intended to only store properties related
137133
# to core Ax functionality and not to any third-system that the trials
@@ -475,22 +471,6 @@ def _get_default_name(self, arm_index: int | None = None) -> str:
475471
arm_index = self._num_arms_created
476472
return f"{self.index}_{arm_index}"
477473

478-
def _set_generation_step_index(self, generation_step_index: int | None) -> None:
479-
"""Sets the `generation_step_index` property of the trial, to reflect which
480-
generation step of a given generation strategy (if any) produced the generator
481-
run(s) attached to this trial.
482-
"""
483-
if (
484-
self._generation_step_index is not None
485-
and generation_step_index is not None
486-
and self._generation_step_index != generation_step_index
487-
):
488-
raise ValueError(
489-
"Cannot add generator runs from different generation steps to a "
490-
"single trial."
491-
)
492-
self._generation_step_index = generation_step_index
493-
494474
@property
495475
def active_arms(self) -> list[Arm]:
496476
"""All non abandoned arms associated with this trial."""

ax/core/batch_trial.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,6 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial:
298298
self._add_generator_run(generator_run=generator_run)
299299

300300
self._generator_runs.append(generator_run)
301-
if generator_run._generation_step_index is not None:
302-
self._set_generation_step_index(
303-
generation_step_index=generator_run._generation_step_index
304-
)
305301
self._refresh_arms_by_name()
306302
return self
307303

ax/core/generator_run.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def __init__(
9797
adapter_kwargs: dict[str, Any] | None = None,
9898
gen_metadata: TGenMetadata | None = None,
9999
generator_state_after_gen: dict[str, Any] | None = None,
100-
generation_step_index: int | None = None,
101100
candidate_metadata_by_arm_signature: None
102101
| (dict[str, TCandidateMetadata]) = None,
103102
generation_node_name: str | None = None,
@@ -135,12 +134,6 @@ def __init__(
135134
model when reinstantiating it to continue generation from it,
136135
rather than to reproduce the conditions, in which this generator
137136
run was created.
138-
generation_step_index: Deprecated in favor of generation_node_name.
139-
Optional index of the generation step that produced this generator run.
140-
Applicable only if the generator run was created via a generation
141-
strategy (in which case this index should reflect the index of
142-
generation step in a generation strategy) or a standalone generation
143-
step (in which case this index should be ``-1``).
144137
candidate_metadata_by_arm_signature: Optional dictionary of arm signatures
145138
to model-produced candidate metadata that corresponds to that arm in
146139
this generator run.
@@ -197,16 +190,6 @@ def __init__(
197190
"candidate metadata, but not among the arms on this GeneratorRun."
198191
)
199192
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature
200-
201-
# Validate that generation step index is either not set (not from generation
202-
# strategy or ste), is non-negative (from generation step) or is -1 (from a
203-
# standalone generation step that was not a part of a generation strategy).
204-
assert (
205-
generation_step_index is None # Not generation strategy/step
206-
or generation_step_index == -1 # Standalone generation step
207-
or generation_step_index >= 0 # Generation strategy
208-
)
209-
self._generation_step_index = generation_step_index
210193
self._generation_node_name = generation_node_name
211194

212195
@property
@@ -342,7 +325,6 @@ def clone(self) -> GeneratorRun:
342325
adapter_kwargs=self._adapter_kwargs,
343326
gen_metadata=self._gen_metadata,
344327
generator_state_after_gen=self._generator_state_after_gen,
345-
generation_step_index=self._generation_step_index,
346328
candidate_metadata_by_arm_signature=cand_metadata,
347329
generation_node_name=self._generation_node_name,
348330
)

ax/core/tests/test_batch_trial.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,6 @@ def test_add_generator_run(self) -> None:
115115
self.assertEqual(len(self.batch.generator_runs), 1)
116116
self.assertEqual(sum(self.batch.weights), sum(self.weights))
117117

118-
# Overwrite the GS index to not-None.
119-
self.batch._generation_step_index = 0
120-
121118
# one of these arms already exists on the BatchTrial,
122119
# so we should just update its weight
123120
new_arms = [
@@ -135,8 +132,6 @@ def test_add_generator_run(self) -> None:
135132
float(sum(self.batch.weights)),
136133
float(sum(self.weights)) + new_gr_total_weight,
137134
)
138-
# Check the GS index was not overwritten to None.
139-
self.assertEqual(self.batch._generation_step_index, 0)
140135

141136
def test_InitWithGeneratorRun(self) -> None:
142137
generator_run = GeneratorRun(arms=self.arms, weights=self.weights)

ax/generation_strategy/dispatch_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def choose_generation_strategy_legacy(
466466
logger.debug(
467467
f"calculated num_initialization_trials={num_initialization_trials}"
468468
)
469-
steps = []
469+
nodes = []
470470
# `disable_progbar` and jit_compile defaults and overrides
471471
model_is_saasbo = use_saasbo and (suggested_model is Generators.BOTORCH_MODULAR)
472472
if disable_progbar is not None and not model_is_saasbo:
@@ -493,7 +493,7 @@ def choose_generation_strategy_legacy(
493493
# Note: With `use_all_trials_in_exp=True`, the Sobol step will automatically
494494
# account for any existing trials in the experiment, so we can use the total
495495
# number of initialization trials directly.
496-
steps.append(
496+
nodes.append(
497497
_make_sobol_step(
498498
num_trials=num_initialization_trials,
499499
min_trials_observed=min_sobol_trials_observed,
@@ -503,7 +503,7 @@ def choose_generation_strategy_legacy(
503503
should_deduplicate=should_deduplicate,
504504
)
505505
)
506-
steps.append(
506+
nodes.append(
507507
_make_botorch_step(
508508
generator=suggested_model,
509509
winsorization_config=winsorization_config,
@@ -518,17 +518,17 @@ def choose_generation_strategy_legacy(
518518
),
519519
)
520520
# set name for GS
521-
bo_step = steps[-1]
522-
surrogate_spec = bo_step.generator_kwargs.get("surrogate_spec")
521+
bo_step = nodes[-1]
522+
surrogate_spec = bo_step.generator_spec.generator_kwargs.get("surrogate_spec")
523523
name = None
524524
if (
525-
bo_step.generator is Generators.BOTORCH_MODULAR
525+
bo_step.generator_spec.generator_enum is Generators.BOTORCH_MODULAR
526526
and surrogate_spec is not None
527527
and (model_config := surrogate_spec.model_configs[0]).botorch_model_class
528528
== SaasFullyBayesianSingleTaskGP
529529
):
530530
name = f"Sobol+{model_config.name}"
531-
gs = GenerationStrategy(steps=steps, name=name)
531+
gs = GenerationStrategy(nodes=nodes, name=name)
532532
logger.info(
533533
f"Using Bayesian Optimization generation strategy: {gs}. Iterations after"
534534
f" {num_initialization_trials} will take longer to generate due"

0 commit comments

Comments
 (0)