Skip to content

Commit 269afe8

Browse files
lena-kashtelyanmeta-codesync[bot]
authored andcommitted
Move all step-based GS functionality to nodes (facebook#4731)
Summary: Pull Request resolved: facebook#4731 A lot going on here: 1. Make `GenerationStep` a factory for `GenerationNode` by replacing its `__init__(self, ...)` constructor (which would have to return a `GStep`) with a `__new__(cls, ...)` construtor, which can return a `GNode` (magic, thanks Devmate!) 2. Adapt storage: stop storing steps and just treat them as nodes. No step-only fields will be saved going forward. Backward compatibility is handled though. 3. Change a bazillion tests and checks in downstream applications. NOTE: need to remember next steps here: https://www.internalfb.com/diff/D86066476?dst_version_fbid=732904879835890&transaction_fbid=1400151778204976, cc mgarrard sorrybigdiff Differential Revision: D80128678 Privacy Context Container: L1307644
1 parent 1afc5b8 commit 269afe8

33 files changed

+1189
-945
lines changed

ax/benchmark/tests/methods/test_methods.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def _test_mbm_acquisition(self, batch_size: int) -> None:
3737
)
3838
self.assertEqual(method.name, expected_name)
3939
gs = method.generation_strategy
40-
sobol, kg = gs._steps
41-
self.assertEqual(kg.generator, Generators.BOTORCH_MODULAR)
42-
generator_kwargs = none_throws(kg.generator_kwargs)
40+
sobol, kg = gs._nodes
41+
self.assertEqual(kg.generator_spec.generator_enum, Generators.BOTORCH_MODULAR)
42+
generator_kwargs = none_throws(kg.generator_spec.generator_kwargs)
4343
self.assertEqual(generator_kwargs["botorch_acqf_class"], qKnowledgeGradient)
4444
surrogate_spec = generator_kwargs["surrogate_spec"]
4545
self.assertEqual(
@@ -64,11 +64,11 @@ def _test_benchmark_replication_runs(
6464
num_sobol_trials=2,
6565
name="test",
6666
)
67-
n_sobol_trials = method.generation_strategy._steps[0].num_trials
68-
self.assertEqual(n_sobol_trials, 2)
67+
n_sobol_trials_tc = method.generation_strategy._nodes[0].num_trials
68+
self.assertEqual(n_sobol_trials_tc, 2)
6969
self.assertEqual(method.name, "test")
7070
# Only run one non-Sobol trial
71-
n_total_trials = n_sobol_trials + 1
71+
n_total_trials = n_sobol_trials_tc + 1
7272
problem = get_benchmark_problem(
7373
problem_key="ackley4", num_trials=n_total_trials
7474
)
@@ -104,5 +104,5 @@ def test_sobol(self) -> None:
104104
method = get_sobol_benchmark_method()
105105
self.assertEqual(method.name, "Sobol")
106106
gs = method.generation_strategy
107-
self.assertEqual(len(gs._steps), 1)
108-
self.assertEqual(gs._steps[0].generator, Generators.SOBOL)
107+
self.assertEqual(len(gs._nodes), 1)
108+
self.assertEqual(gs._nodes[0].generator_spec.generator_enum, Generators.SOBOL)

ax/benchmark/tests/test_benchmark_method.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ def test_benchmark_method(self) -> None:
1818
self.assertEqual(method.name, "Sobol10")
1919

2020
# test that `fit_tracking_metrics` has been correctly set to False
21-
for step in method.generation_strategy._steps:
21+
for step in method.generation_strategy._nodes:
2222
self.assertFalse(
23-
none_throws(step.generator_kwargs).get("fit_tracking_metrics")
23+
none_throws(step.generator_specs[0].generator_kwargs).get(
24+
"fit_tracking_metrics"
25+
)
2426
)
2527

2628
method = BenchmarkMethod(generation_strategy=gs)

ax/core/base_trial.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ def __init__(
132132
# Counter to maintain how many arms have been named by this BatchTrial
133133
self._num_arms_created = 0
134134

135-
# If generator run(s) in this trial were generated from a generation
136-
# strategy, this property will be set to the generation step that produced
137-
# the generator run(s).
138-
self._generation_step_index: int | None = None
139135
# NOTE: Please do not store any data related to trial deployment or data-
140136
# fetching in properties. It is intended to only store properties related
141137
# to core Ax functionality and not to any third-system that the trials
@@ -483,22 +479,6 @@ def _get_default_name(self, arm_index: int | None = None) -> str:
483479
arm_index = self._num_arms_created
484480
return f"{self.index}_{arm_index}"
485481

486-
def _set_generation_step_index(self, generation_step_index: int | None) -> None:
487-
"""Sets the `generation_step_index` property of the trial, to reflect which
488-
generation step of a given generation strategy (if any) produced the generator
489-
run(s) attached to this trial.
490-
"""
491-
if (
492-
self._generation_step_index is not None
493-
and generation_step_index is not None
494-
and self._generation_step_index != generation_step_index
495-
):
496-
raise ValueError(
497-
"Cannot add generator runs from different generation steps to a "
498-
"single trial."
499-
)
500-
self._generation_step_index = generation_step_index
501-
502482
@property
503483
def active_arms(self) -> list[Arm]:
504484
"""All non abandoned arms associated with this trial."""

ax/core/batch_trial.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,6 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial:
298298
self._add_generator_run(generator_run=generator_run)
299299

300300
self._generator_runs.append(generator_run)
301-
if generator_run._generation_step_index is not None:
302-
self._set_generation_step_index(
303-
generation_step_index=generator_run._generation_step_index
304-
)
305301
self._refresh_arms_by_name()
306302
return self
307303

ax/core/generator_run.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def __init__(
9797
adapter_kwargs: dict[str, Any] | None = None,
9898
gen_metadata: TGenMetadata | None = None,
9999
generator_state_after_gen: dict[str, Any] | None = None,
100-
generation_step_index: int | None = None,
101100
candidate_metadata_by_arm_signature: None
102101
| (dict[str, TCandidateMetadata]) = None,
103102
generation_node_name: str | None = None,
@@ -135,12 +134,6 @@ def __init__(
135134
model when reinstantiating it to continue generation from it,
136135
rather than to reproduce the conditions, in which this generator
137136
run was created.
138-
generation_step_index: Deprecated in favor of generation_node_name.
139-
Optional index of the generation step that produced this generator run.
140-
Applicable only if the generator run was created via a generation
141-
strategy (in which case this index should reflect the index of
142-
generation step in a generation strategy) or a standalone generation
143-
step (in which case this index should be ``-1``).
144137
candidate_metadata_by_arm_signature: Optional dictionary of arm signatures
145138
to model-produced candidate metadata that corresponds to that arm in
146139
this generator run.
@@ -197,16 +190,6 @@ def __init__(
197190
"candidate metadata, but not among the arms on this GeneratorRun."
198191
)
199192
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature
200-
201-
# Validate that generation step index is either not set (not from generation
202-
# strategy or ste), is non-negative (from generation step) or is -1 (from a
203-
# standalone generation step that was not a part of a generation strategy).
204-
assert (
205-
generation_step_index is None # Not generation strategy/step
206-
or generation_step_index == -1 # Standalone generation step
207-
or generation_step_index >= 0 # Generation strategy
208-
)
209-
self._generation_step_index = generation_step_index
210193
self._generation_node_name = generation_node_name
211194

212195
@property
@@ -342,7 +325,6 @@ def clone(self) -> GeneratorRun:
342325
adapter_kwargs=self._adapter_kwargs,
343326
gen_metadata=self._gen_metadata,
344327
generator_state_after_gen=self._generator_state_after_gen,
345-
generation_step_index=self._generation_step_index,
346328
candidate_metadata_by_arm_signature=cand_metadata,
347329
generation_node_name=self._generation_node_name,
348330
)

ax/core/tests/test_batch_trial.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,6 @@ def test_add_generator_run(self) -> None:
115115
self.assertEqual(len(self.batch.generator_runs), 1)
116116
self.assertEqual(sum(self.batch.weights), sum(self.weights))
117117

118-
# Overwrite the GS index to not-None.
119-
self.batch._generation_step_index = 0
120-
121118
# one of these arms already exists on the BatchTrial,
122119
# so we should just update its weight
123120
new_arms = [
@@ -135,8 +132,6 @@ def test_add_generator_run(self) -> None:
135132
float(sum(self.batch.weights)),
136133
float(sum(self.weights)) + new_gr_total_weight,
137134
)
138-
# Check the GS index was not overwritten to None.
139-
self.assertEqual(self.batch._generation_step_index, 0)
140135

141136
def test_InitWithGeneratorRun(self) -> None:
142137
generator_run = GeneratorRun(arms=self.arms, weights=self.weights)

ax/generation_strategy/center_generation_node.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -112,35 +112,6 @@ def gen(
112112
**gs_gen_kwargs,
113113
)
114114

115-
def _compute_center_params(self) -> TParameterization:
116-
"""Compute the center of the search space."""
117-
search_space = none_throws(self.search_space)
118-
parameters = {}
119-
derived_params = []
120-
for name, p in search_space.parameters.items():
121-
if isinstance(p, RangeParameter):
122-
if p.logit_scale:
123-
# Leverage scipy's numerically stable logit and expit functions
124-
center = expit((logit(p.lower) + logit(p.upper)) / 2.0)
125-
elif p.log_scale:
126-
center = 10 ** ((math.log10(p.lower) + math.log10(p.upper)) / 2.0)
127-
else:
128-
center = (float(p.lower) + float(p.upper)) / 2.0
129-
parameters[name] = p.cast(center)
130-
elif isinstance(p, ChoiceParameter):
131-
parameters[name] = p.values[int(len(p.values) / 2)]
132-
elif isinstance(p, FixedParameter):
133-
parameters[name] = p.value
134-
elif isinstance(p, DerivedParameter):
135-
derived_params.append(p)
136-
else:
137-
raise NotImplementedError(f"Parameter type {type(p)} is not supported.")
138-
for p in derived_params:
139-
parameters[p.name] = p.compute(parameters=parameters)
140-
if search_space.is_hierarchical:
141-
parameters = search_space._cast_parameterization(parameters=parameters)
142-
return parameters
143-
144115
def get_next_candidate(
145116
self, pending_parameters: list[TParameterization]
146117
) -> TParameterization:
@@ -171,3 +142,32 @@ def get_next_candidate(
171142
"The generation strategy will fallback to Sobol. "
172143
)
173144
return parameters
145+
146+
def _compute_center_params(self) -> TParameterization:
147+
"""Compute the center of the search space."""
148+
search_space = none_throws(self.search_space)
149+
parameters = {}
150+
derived_params = []
151+
for name, p in search_space.parameters.items():
152+
if isinstance(p, RangeParameter):
153+
if p.logit_scale:
154+
# Leverage scipy's numerically stable logit and expit functions
155+
center = expit((logit(p.lower) + logit(p.upper)) / 2.0)
156+
elif p.log_scale:
157+
center = 10 ** ((math.log10(p.lower) + math.log10(p.upper)) / 2.0)
158+
else:
159+
center = (float(p.lower) + float(p.upper)) / 2.0
160+
parameters[name] = p.cast(center)
161+
elif isinstance(p, ChoiceParameter):
162+
parameters[name] = p.values[int(len(p.values) / 2)]
163+
elif isinstance(p, FixedParameter):
164+
parameters[name] = p.value
165+
elif isinstance(p, DerivedParameter):
166+
derived_params.append(p)
167+
else:
168+
raise NotImplementedError(f"Parameter type {type(p)} is not supported.")
169+
for p in derived_params:
170+
parameters[p.name] = p.compute(parameters=parameters)
171+
if search_space.is_hierarchical:
172+
parameters = search_space._cast_parameterization(parameters=parameters)
173+
return parameters

ax/generation_strategy/dispatch_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def choose_generation_strategy_legacy(
466466
logger.debug(
467467
f"calculated num_initialization_trials={num_initialization_trials}"
468468
)
469-
steps = []
469+
nodes = []
470470
# `disable_progbar` and jit_compile defaults and overrides
471471
model_is_saasbo = use_saasbo and (suggested_model is Generators.BOTORCH_MODULAR)
472472
if disable_progbar is not None and not model_is_saasbo:
@@ -493,7 +493,7 @@ def choose_generation_strategy_legacy(
493493
# Note: With `use_all_trials_in_exp=True`, the Sobol step will automatically
494494
# account for any existing trials in the experiment, so we can use the total
495495
# number of initialization trials directly.
496-
steps.append(
496+
nodes.append(
497497
_make_sobol_step(
498498
num_trials=num_initialization_trials,
499499
min_trials_observed=min_sobol_trials_observed,
@@ -503,7 +503,7 @@ def choose_generation_strategy_legacy(
503503
should_deduplicate=should_deduplicate,
504504
)
505505
)
506-
steps.append(
506+
nodes.append(
507507
_make_botorch_step(
508508
generator=suggested_model,
509509
winsorization_config=winsorization_config,
@@ -518,17 +518,17 @@ def choose_generation_strategy_legacy(
518518
),
519519
)
520520
# set name for GS
521-
bo_step = steps[-1]
522-
surrogate_spec = bo_step.generator_kwargs.get("surrogate_spec")
521+
bo_step = nodes[-1]
522+
surrogate_spec = bo_step.generator_spec.generator_kwargs.get("surrogate_spec")
523523
name = None
524524
if (
525-
bo_step.generator is Generators.BOTORCH_MODULAR
525+
bo_step.generator_spec.generator_enum is Generators.BOTORCH_MODULAR
526526
and surrogate_spec is not None
527527
and (model_config := surrogate_spec.model_configs[0]).botorch_model_class
528528
== SaasFullyBayesianSingleTaskGP
529529
):
530530
name = f"Sobol+{model_config.name}"
531-
gs = GenerationStrategy(steps=steps, name=name)
531+
gs = GenerationStrategy(nodes=nodes, name=name)
532532
logger.info(
533533
f"Using Bayesian Optimization generation strategy: {gs}. Iterations after"
534534
f" {num_initialization_trials} will take longer to generate due"

ax/generation_strategy/external_generation_node.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from ax.generation_strategy.transition_criterion import TransitionCriterion
2424

2525

26-
# TODO[drfreund]: Introduce a `GenerationNodeInterface` to
27-
# make inheritance/overriding of `GenNode` methods cleaner.
2826
class ExternalGenerationNode(GenerationNode, ABC):
2927
"""A generation node intended to be used with non-Ax methods for
3028
candidate generation.

0 commit comments

Comments
 (0)