Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 72 additions & 21 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# pyre-strict


from typing import Any

import torch
from ax.adapter.registry import Generators
from ax.api.utils.structs import GenerationStrategyDispatchStruct
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.generation_strategy.center_generation_node import CenterGenerationNode
from ax.generation_strategy.dispatch_utils import get_derelativize_config
from ax.generation_strategy.generation_strategy import (
Expand All @@ -21,7 +23,9 @@
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MinTrials
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from pyre_extensions import none_throws


def _get_sobol_node(
Expand Down Expand Up @@ -96,17 +100,33 @@ def _get_mbm_node(
method: str,
torch_device: str | None,
simplify_parameter_changes: bool,
) -> GenerationNode:
model_config: ModelConfig | None = None,
botorch_acqf_class: type[AcquisitionFunction] | None = None,
) -> tuple[GenerationNode, str]:
"""Constructs an MBM node based on the method specified in
``struct``.

The ``SurrogateSpec`` takes the following form for the given method:
- BALANCED: Two model configs: one with MBM defaults, the other with
linear kernel with input warping.
- FAST: An empty model config that utilizes MBM defaults.
Args:
method: The method to use for the MBM node. This can be one of
- "quality": Uses Warped SAAS model.
- "fast": Uses MBM defaults.
- "custom": Uses the provided ``model_config``.
torch_device: The torch device to use for the MBM node.
simplify_parameter_changes: Whether to simplify parameter changes in
the MBM node.
model_config: Optional model config to use for the MBM node.
This is only supported when ``method`` is "custom".
botorch_acqf_class: An optional BoTorch ``AcquisitionFunction`` class
to use for the MBM node.
"""
# Construct the surrogate spec.
if method == "quality":
if method == "custom":
model_config = none_throws(model_config)
model_configs = [model_config]
mbm_name = (
model_config.name if model_config.name is not None else "custom_config"
)
elif method == "quality":
model_configs = [
ModelConfig(
botorch_model_class=SaasFullyBayesianSingleTaskGP,
Expand All @@ -117,36 +137,49 @@ def _get_mbm_node(
name="WarpedSAAS",
)
]
mbm_name = method
elif method == "fast":
model_configs = [ModelConfig(name="MBM defaults")]
mbm_name = method
else:
raise UnsupportedError(f"Unsupported generation method: {method}.")

# Append acquisition function class name to the node name if provided.
if botorch_acqf_class is not None:
mbm_name = f"{mbm_name}+{botorch_acqf_class.__name__}"

device = None if torch_device is None else torch.device(torch_device)

# Construct generator kwargs.
generator_kwargs: dict[str, Any] = {
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
"torch_device": device,
"transform_configs": get_derelativize_config(
derelativize_with_raw_status_quo=True
),
"acquisition_options": {
"prune_irrelevant_parameters": simplify_parameter_changes
},
}
if botorch_acqf_class is not None:
generator_kwargs["botorch_acqf_class"] = botorch_acqf_class

return GenerationNode(
name="MBM",
generator_specs=[
GeneratorSpec(
generator_enum=Generators.BOTORCH_MODULAR,
generator_kwargs={
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
"torch_device": device,
"transform_configs": get_derelativize_config(
derelativize_with_raw_status_quo=True
),
"acquisition_options": {
"prune_irrelevant_parameters": simplify_parameter_changes
},
},
generator_kwargs=generator_kwargs,
)
],
should_deduplicate=True,
)
), mbm_name


def choose_generation_strategy(
struct: GenerationStrategyDispatchStruct,
model_config: ModelConfig | None = None,
botorch_acqf_class: type[AcquisitionFunction] | None = None,
) -> GenerationStrategy:
"""
Choose a generation strategy based on the properties of the experiment and the
Expand All @@ -159,10 +192,26 @@ def choose_generation_strategy(
struct: A ``GenerationStrategyDispatchStruct``
object that informs
the choice of generation strategy.
model_config: An optional ``ModelConfig`` to use for the Bayesian optimization
phase. This must be provided when ``struct.method`` is ``"custom"``, and
must not be provided otherwise.
botorch_acqf_class: An optional BoTorch ``AcquisitionFunction`` class to use
for the Bayesian optimization phase. When provided, it will be passed as a
model kwarg to the MBM node and its name will be appended to the node name.

Returns:
A generation strategy.
"""
# Validate model_config usage.
if struct.method == "custom":
if model_config is None:
raise UserInputError("model_config must be provided when method='custom'.")
elif model_config is not None:
raise UserInputError(
"model_config should only be provided when method='custom'. "
f"Got method='{struct.method}'."
)

# Handle the random search case.
if struct.method == "random_search":
nodes = [
Expand All @@ -178,10 +227,12 @@ def choose_generation_strategy(
]
gs_name = "QuasiRandomSearch"
else:
mbm_node = _get_mbm_node(
mbm_node, mbm_name = _get_mbm_node(
method=struct.method,
torch_device=struct.torch_device,
simplify_parameter_changes=struct.simplify_parameter_changes,
model_config=model_config,
botorch_acqf_class=botorch_acqf_class,
)
if (
struct.initialization_budget is None
Expand All @@ -198,10 +249,10 @@ def choose_generation_strategy(
),
mbm_node,
]
gs_name = f"Sobol+MBM:{struct.method}"
gs_name = f"Sobol+MBM:{mbm_name}"
else:
nodes = [mbm_node]
gs_name = f"MBM:{struct.method}"
gs_name = f"MBM:{mbm_name}"
if struct.initialize_with_center and (
struct.initialization_budget is None or struct.initialization_budget > 0
):
Expand Down
7 changes: 6 additions & 1 deletion ax/api/utils/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class GenerationStrategyDispatchStruct:
- ``"random_search"``, primarily intended for pure exploration
experiments, this method utilizes quasi-random Sobol sequences
for candidate generation.
- ``"custom"``, allows using a custom ``ModelConfig`` for the
Bayesian optimization phase. When using this method, the
``model_config`` argument must be provided to
``choose_generation_strategy``. This is an advanced option
and should not be considered a part of the public API.
initialization_budget: The number of trials to use for initialization.
If ``None``, a default budget of 5 trials is used.
initialization_random_seed: The random seed to use with the Sobol generator
Expand Down Expand Up @@ -87,7 +92,7 @@ class GenerationStrategyDispatchStruct:
irrelevant parameters.
"""

method: Literal["quality", "fast", "random_search"] = "fast"
method: Literal["quality", "fast", "random_search", "custom"] = "fast"
# Initialization options
initialization_budget: int | None = None
initialization_random_seed: int | None = None
Expand Down
118 changes: 118 additions & 0 deletions ax/api/utils/tests/test_generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ax.api.utils.structs import GenerationStrategyDispatchStruct
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
from ax.generation_strategy.center_generation_node import CenterGenerationNode
from ax.generation_strategy.dispatch_utils import get_derelativize_config
from ax.generation_strategy.transition_criterion import MinTrials
Expand All @@ -27,7 +28,9 @@
)
from ax.utils.testing.mock import mock_botorch_optimize
from ax.utils.testing.utils import run_trials_with_gs
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP
from pyre_extensions import assert_is_instance, none_throws


Expand Down Expand Up @@ -252,3 +255,118 @@ def test_gs_simplify_parameter_changes(self) -> None:
mbm_spec.generator_kwargs["acquisition_options"],
{"prune_irrelevant_parameters": simplify},
)

def test_choose_gs_custom_with_model_config(self) -> None:
"""Test that custom method works with a provided ModelConfig."""
custom_model_config = ModelConfig(
botorch_model_class=EnsembleMapSaasSingleTaskGP,
name="MAPSAAS",
)
struct = GenerationStrategyDispatchStruct(
method="custom",
initialization_budget=3,
initialize_with_center=False,
torch_device="cpu",
)
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
self.assertEqual(len(gs._nodes), 2)
self.assertEqual(gs.name, "Sobol+MBM:MAPSAAS")

# Check the MBM node uses the custom model config.
mbm_node = gs._nodes[1]
self.assertEqual(len(mbm_node.generator_specs), 1)
mbm_spec = mbm_node.generator_specs[0]
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
expected_ss = SurrogateSpec(model_configs=[custom_model_config])
self.assertEqual(
mbm_spec.generator_kwargs["surrogate_spec"],
expected_ss,
)
self.assertEqual(
mbm_spec.generator_kwargs["torch_device"],
torch.device("cpu"),
)

def test_choose_gs_custom_without_name(self) -> None:
"""Test that custom method works with unnamed ModelConfig."""
custom_model_config = ModelConfig(
botorch_model_class=SaasFullyBayesianSingleTaskGP,
# No name provided.
)
struct = GenerationStrategyDispatchStruct(
method="custom",
initialization_budget=3,
initialize_with_center=False,
)
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
# Should use "custom_config" as the default name.
self.assertEqual(gs.name, "Sobol+MBM:custom_config")

def test_choose_gs_custom_model_config_validation(self) -> None:
"""Test validation of model_config and custom method pairing."""
# Test that custom method raises an error when model_config is not provided.
struct = GenerationStrategyDispatchStruct(method="custom")
with self.assertRaisesRegex(
UserInputError,
"model_config must be provided when method='custom'.",
):
choose_generation_strategy(struct=struct)

# Test that providing model_config without custom method raises an error.
custom_model_config = ModelConfig(name="SomeConfig")
struct = GenerationStrategyDispatchStruct(method="fast")
with self.assertRaisesRegex(
UserInputError,
"model_config should only be provided when method='custom'. "
"Got method='fast'.",
):
choose_generation_strategy(struct=struct, model_config=custom_model_config)

def test_choose_gs_with_custom_botorch_acqf_class(self) -> None:
"""Test that custom botorch_acqf_class is properly passed to generator kwargs
and appended to the node name. Tests both fast and custom methods.
"""
for method, model_config, expected_name in [
("fast", None, "Sobol+MBM:fast+qLogNoisyExpectedImprovement"),
(
"custom",
ModelConfig(
botorch_model_class=EnsembleMapSaasSingleTaskGP,
name="MAPSAAS",
),
"Sobol+MBM:MAPSAAS+qLogNoisyExpectedImprovement",
),
]:
with self.subTest(method=method):
struct = GenerationStrategyDispatchStruct(
method=method, # pyre-ignore [6]
initialization_budget=3,
initialize_with_center=False,
)
gs = choose_generation_strategy(
struct=struct,
model_config=model_config,
botorch_acqf_class=qLogNoisyExpectedImprovement,
)
# Check that the name includes the acquisition function class name.
self.assertEqual(gs.name, expected_name)

# Check that MBM node generator kwargs include the botorch_acqf_class.
mbm_node = gs._nodes[1]
self.assertEqual(len(mbm_node.generator_specs), 1)
mbm_spec = mbm_node.generator_specs[0]
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
self.assertEqual(
mbm_spec.generator_kwargs["botorch_acqf_class"],
qLogNoisyExpectedImprovement,
)
# Check surrogate spec uses expected model config.
expected_model_config = (
model_config
if model_config is not None
else ModelConfig(name="MBM defaults")
)
expected_ss = SurrogateSpec(model_configs=[expected_model_config])
self.assertEqual(
mbm_spec.generator_kwargs["surrogate_spec"], expected_ss
)
34 changes: 34 additions & 0 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import unittest
from logging import Logger
from typing import Any

Expand Down Expand Up @@ -62,6 +63,7 @@
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from pyre_extensions import assert_is_instance

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -337,6 +339,38 @@ def sobol_gpei_generation_node_gs(
return sobol_mbm_GS_nodes


def check_sobol_node(
test_case: unittest.TestCase,
gs: GenerationStrategy,
expected_num_trials: int,
expected_min_trials_observed: int | None = None,
) -> None:
"""Helper to check common Sobol node properties.

Args:
test_case: The test case instance for assertions.
gs: The generation strategy to check.
expected_num_trials: The expected number of trials (threshold).
expected_min_trials_observed: The expected min_trials_observed threshold.
If None, the check is skipped.
"""
sobol_node = gs._nodes[0]
test_case.assertEqual(
sobol_node.generator_specs[0].generator_enum, Generators.SOBOL
)
# First MinTrials criterion has the num_trials threshold.
test_case.assertEqual(
assert_is_instance(sobol_node.transition_criteria[0], MinTrials).threshold,
expected_num_trials,
)
if expected_min_trials_observed is not None:
# Second MinTrials criterion has the min_trials_observed threshold.
test_case.assertEqual(
assert_is_instance(sobol_node.transition_criteria[1], MinTrials).threshold,
expected_min_trials_observed,
)


def get_sobol_MBM_MTGP_gs() -> GenerationStrategy:
return GenerationStrategy(
nodes=[
Expand Down