Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions ax/benchmark/tests/methods/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def _test_mbm_acquisition(self, batch_size: int) -> None:
)
self.assertEqual(method.name, expected_name)
gs = method.generation_strategy
sobol, kg = gs._steps
self.assertEqual(kg.generator, Generators.BOTORCH_MODULAR)
generator_kwargs = none_throws(kg.generator_kwargs)
sobol, kg = gs._nodes
self.assertEqual(kg.generator_spec.generator_enum, Generators.BOTORCH_MODULAR)
generator_kwargs = none_throws(kg.generator_spec.generator_kwargs)
self.assertEqual(generator_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = generator_kwargs["surrogate_spec"]
self.assertEqual(
Expand All @@ -64,11 +64,11 @@ def _test_benchmark_replication_runs(
num_sobol_trials=2,
name="test",
)
n_sobol_trials = method.generation_strategy._steps[0].num_trials
self.assertEqual(n_sobol_trials, 2)
n_sobol_trials_tc = method.generation_strategy._nodes[0].num_trials
self.assertEqual(n_sobol_trials_tc, 2)
self.assertEqual(method.name, "test")
# Only run one non-Sobol trial
n_total_trials = n_sobol_trials + 1
n_total_trials = n_sobol_trials_tc + 1
problem = get_benchmark_problem(
problem_key="ackley4", num_trials=n_total_trials
)
Expand Down Expand Up @@ -104,5 +104,5 @@ def test_sobol(self) -> None:
method = get_sobol_benchmark_method()
self.assertEqual(method.name, "Sobol")
gs = method.generation_strategy
self.assertEqual(len(gs._steps), 1)
self.assertEqual(gs._steps[0].generator, Generators.SOBOL)
self.assertEqual(len(gs._nodes), 1)
self.assertEqual(gs._nodes[0].generator_spec.generator_enum, Generators.SOBOL)
6 changes: 4 additions & 2 deletions ax/benchmark/tests/test_benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def test_benchmark_method(self) -> None:
self.assertEqual(method.name, "Sobol10")

# test that `fit_tracking_metrics` has been correctly set to False
for step in method.generation_strategy._steps:
for step in method.generation_strategy._nodes:
self.assertFalse(
none_throws(step.generator_kwargs).get("fit_tracking_metrics")
none_throws(step.generator_specs[0].generator_kwargs).get(
"fit_tracking_metrics"
)
)

method = BenchmarkMethod(generation_strategy=gs)
Expand Down
20 changes: 0 additions & 20 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ def __init__(
# Counter to maintain how many arms have been named by this BatchTrial
self._num_arms_created = 0

# If generator run(s) in this trial were generated from a generation
# strategy, this property will be set to the generation step that produced
# the generator run(s).
self._generation_step_index: int | None = None
# NOTE: Please do not store any data related to trial deployment or data-
# fetching in properties. It is intended to only store properties related
# to core Ax functionality and not to any third-system that the trials
Expand Down Expand Up @@ -476,22 +472,6 @@ def _get_default_name(self, arm_index: int | None = None) -> str:
arm_index = self._num_arms_created
return f"{self.index}_{arm_index}"

def _set_generation_step_index(self, generation_step_index: int | None) -> None:
"""Sets the `generation_step_index` property of the trial, to reflect which
generation step of a given generation strategy (if any) produced the generator
run(s) attached to this trial.
"""
if (
self._generation_step_index is not None
and generation_step_index is not None
and self._generation_step_index != generation_step_index
):
raise ValueError(
"Cannot add generator runs from different generation steps to a "
"single trial."
)
self._generation_step_index = generation_step_index

@property
def active_arms(self) -> list[Arm]:
"""All non abandoned arms associated with this trial."""
Expand Down
4 changes: 0 additions & 4 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,6 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial:
self._add_generator_run(generator_run=generator_run)

self._generator_runs.append(generator_run)
if generator_run._generation_step_index is not None:
self._set_generation_step_index(
generation_step_index=generator_run._generation_step_index
)
self._refresh_arms_by_name()
return self

Expand Down
18 changes: 0 additions & 18 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __init__(
adapter_kwargs: dict[str, Any] | None = None,
gen_metadata: TGenMetadata | None = None,
generator_state_after_gen: dict[str, Any] | None = None,
generation_step_index: int | None = None,
candidate_metadata_by_arm_signature: None
| (dict[str, TCandidateMetadata]) = None,
generation_node_name: str | None = None,
Expand Down Expand Up @@ -135,12 +134,6 @@ def __init__(
model when reinstantiating it to continue generation from it,
rather than to reproduce the conditions, in which this generator
run was created.
generation_step_index: Deprecated in favor of generation_node_name.
Optional index of the generation step that produced this generator run.
Applicable only if the generator run was created via a generation
strategy (in which case this index should reflect the index of
generation step in a generation strategy) or a standalone generation
step (in which case this index should be ``-1``).
candidate_metadata_by_arm_signature: Optional dictionary of arm signatures
to model-produced candidate metadata that corresponds to that arm in
this generator run.
Expand Down Expand Up @@ -197,16 +190,6 @@ def __init__(
"candidate metadata, but not among the arms on this GeneratorRun."
)
self._candidate_metadata_by_arm_signature = candidate_metadata_by_arm_signature

# Validate that generation step index is either not set (not from generation
# strategy or ste), is non-negative (from generation step) or is -1 (from a
# standalone generation step that was not a part of a generation strategy).
assert (
generation_step_index is None # Not generation strategy/step
or generation_step_index == -1 # Standalone generation step
or generation_step_index >= 0 # Generation strategy
)
self._generation_step_index = generation_step_index
self._generation_node_name = generation_node_name

@property
Expand Down Expand Up @@ -342,7 +325,6 @@ def clone(self) -> GeneratorRun:
adapter_kwargs=self._adapter_kwargs,
gen_metadata=self._gen_metadata,
generator_state_after_gen=self._generator_state_after_gen,
generation_step_index=self._generation_step_index,
candidate_metadata_by_arm_signature=cand_metadata,
generation_node_name=self._generation_node_name,
)
Expand Down
5 changes: 0 additions & 5 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ def test_add_generator_run(self) -> None:
self.assertEqual(len(self.batch.generator_runs), 1)
self.assertEqual(sum(self.batch.weights), sum(self.weights))

# Overwrite the GS index to not-None.
self.batch._generation_step_index = 0

# one of these arms already exists on the BatchTrial,
# so we should just update its weight
new_arms = [
Expand All @@ -135,8 +132,6 @@ def test_add_generator_run(self) -> None:
float(sum(self.batch.weights)),
float(sum(self.weights)) + new_gr_total_weight,
)
# Check the GS index was not overwritten to None.
self.assertEqual(self.batch._generation_step_index, 0)

def test_InitWithGeneratorRun(self) -> None:
generator_run = GeneratorRun(arms=self.arms, weights=self.weights)
Expand Down
58 changes: 29 additions & 29 deletions ax/generation_strategy/center_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,35 +112,6 @@ def gen(
**gs_gen_kwargs,
)

def _compute_center_params(self) -> TParameterization:
"""Compute the center of the search space."""
search_space = none_throws(self.search_space)
parameters = {}
derived_params = []
for name, p in search_space.parameters.items():
if isinstance(p, RangeParameter):
if p.logit_scale:
# Leverage scipy's numerically stable logit and expit functions
center = expit((logit(p.lower) + logit(p.upper)) / 2.0)
elif p.log_scale:
center = 10 ** ((math.log10(p.lower) + math.log10(p.upper)) / 2.0)
else:
center = (float(p.lower) + float(p.upper)) / 2.0
parameters[name] = p.cast(center)
elif isinstance(p, ChoiceParameter):
parameters[name] = p.values[int(len(p.values) / 2)]
elif isinstance(p, FixedParameter):
parameters[name] = p.value
elif isinstance(p, DerivedParameter):
derived_params.append(p)
else:
raise NotImplementedError(f"Parameter type {type(p)} is not supported.")
for p in derived_params:
parameters[p.name] = p.compute(parameters=parameters)
if search_space.is_hierarchical:
parameters = search_space._cast_parameterization(parameters=parameters)
return parameters

def get_next_candidate(
self, pending_parameters: list[TParameterization]
) -> TParameterization:
Expand Down Expand Up @@ -171,3 +142,32 @@ def get_next_candidate(
"The generation strategy will fallback to Sobol. "
)
return parameters

def _compute_center_params(self) -> TParameterization:
"""Compute the center of the search space."""
search_space = none_throws(self.search_space)
parameters = {}
derived_params = []
for name, p in search_space.parameters.items():
if isinstance(p, RangeParameter):
if p.logit_scale:
# Leverage scipy's numerically stable logit and expit functions
center = expit((logit(p.lower) + logit(p.upper)) / 2.0)
elif p.log_scale:
center = 10 ** ((math.log10(p.lower) + math.log10(p.upper)) / 2.0)
else:
center = (float(p.lower) + float(p.upper)) / 2.0
parameters[name] = p.cast(center)
elif isinstance(p, ChoiceParameter):
parameters[name] = p.values[int(len(p.values) / 2)]
elif isinstance(p, FixedParameter):
parameters[name] = p.value
elif isinstance(p, DerivedParameter):
derived_params.append(p)
else:
raise NotImplementedError(f"Parameter type {type(p)} is not supported.")
for p in derived_params:
parameters[p.name] = p.compute(parameters=parameters)
if search_space.is_hierarchical:
parameters = search_space._cast_parameterization(parameters=parameters)
return parameters
14 changes: 7 additions & 7 deletions ax/generation_strategy/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def choose_generation_strategy_legacy(
logger.debug(
f"calculated num_initialization_trials={num_initialization_trials}"
)
steps = []
nodes = []
# `disable_progbar` and jit_compile defaults and overrides
model_is_saasbo = use_saasbo and (suggested_model is Generators.BOTORCH_MODULAR)
if disable_progbar is not None and not model_is_saasbo:
Expand All @@ -493,7 +493,7 @@ def choose_generation_strategy_legacy(
# Note: With `use_all_trials_in_exp=True`, the Sobol step will automatically
# account for any existing trials in the experiment, so we can use the total
# number of initialization trials directly.
steps.append(
nodes.append(
_make_sobol_step(
num_trials=num_initialization_trials,
min_trials_observed=min_sobol_trials_observed,
Expand All @@ -503,7 +503,7 @@ def choose_generation_strategy_legacy(
should_deduplicate=should_deduplicate,
)
)
steps.append(
nodes.append(
_make_botorch_step(
generator=suggested_model,
winsorization_config=winsorization_config,
Expand All @@ -518,17 +518,17 @@ def choose_generation_strategy_legacy(
),
)
# set name for GS
bo_step = steps[-1]
surrogate_spec = bo_step.generator_kwargs.get("surrogate_spec")
bo_step = nodes[-1]
surrogate_spec = bo_step.generator_spec.generator_kwargs.get("surrogate_spec")
name = None
if (
bo_step.generator is Generators.BOTORCH_MODULAR
bo_step.generator_spec.generator_enum is Generators.BOTORCH_MODULAR
and surrogate_spec is not None
and (model_config := surrogate_spec.model_configs[0]).botorch_model_class
== SaasFullyBayesianSingleTaskGP
):
name = f"Sobol+{model_config.name}"
gs = GenerationStrategy(steps=steps, name=name)
gs = GenerationStrategy(nodes=nodes, name=name)
logger.info(
f"Using Bayesian Optimization generation strategy: {gs}. Iterations after"
f" {num_initialization_trials} will take longer to generate due"
Expand Down
2 changes: 0 additions & 2 deletions ax/generation_strategy/external_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from ax.generation_strategy.transition_criterion import TransitionCriterion


# TODO[drfreund]: Introduce a `GenerationNodeInterface` to
# make inheritance/overriding of `GenNode` methods cleaner.
class ExternalGenerationNode(GenerationNode, ABC):
"""A generation node intended to be used with non-Ax methods for
candidate generation.
Expand Down
Loading