Skip to content

Commit d740cb4

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Add custom method with model_config to choose_gs (#4727)
Summary: Pull Request resolved: #4727 This diff adds the ability to use custom `ModelConfig` objects and a custom `botorch_acqf_class` with `choose_generation_strategy` by introducing a new `"custom"` method option. ## Motivation After D83062555, the internal `choose_gs_internal` function now uses the OSS `choose_generation_strategy` for default GS dispatch. However, there was no way to pass custom model configurations (like MAP SAAS) through this function - they had to be handled as special cases outside the dispatch flow. ## Solution Rather than adding specific method types for each model variant (which would turn the Literal type into a "zoo" of options), we add a single `"custom"` method that allows passing a `ModelConfig` as a separate function argument. This keeps the API clean while providing full flexibility. Key changes: - Add `"custom"` to the `method` Literal type in `GenerationStrategyDispatchStruct` - Add optional `model_config` parameter to `choose_generation_strategy()` - Add validation: `model_config` must be provided iff `method="custom"` - Update `_get_mbm_node()` to handle custom model configs Reviewed By: sdaulton Differential Revision: D89906836 fbshipit-source-id: ac7a380abd7236a06a059f3b1a335baf699b9c17
1 parent 783f108 commit d740cb4

File tree

3 files changed

+196
-22
lines changed

3 files changed

+196
-22
lines changed

ax/api/utils/generation_strategy_dispatch.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
# pyre-strict
88

99

10+
from typing import Any
11+
1012
import torch
1113
from ax.adapter.registry import Generators
1214
from ax.api.utils.structs import GenerationStrategyDispatchStruct
1315
from ax.core.trial_status import TrialStatus
14-
from ax.exceptions.core import UnsupportedError
16+
from ax.exceptions.core import UnsupportedError, UserInputError
1517
from ax.generation_strategy.center_generation_node import CenterGenerationNode
1618
from ax.generation_strategy.dispatch_utils import get_derelativize_config
1719
from ax.generation_strategy.generation_strategy import (
@@ -21,7 +23,9 @@
2123
from ax.generation_strategy.generator_spec import GeneratorSpec
2224
from ax.generation_strategy.transition_criterion import MinTrials
2325
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
26+
from botorch.acquisition.acquisition import AcquisitionFunction
2427
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
28+
from pyre_extensions import none_throws
2529

2630

2731
def _get_sobol_node(
@@ -96,17 +100,33 @@ def _get_mbm_node(
96100
method: str,
97101
torch_device: str | None,
98102
simplify_parameter_changes: bool,
99-
) -> GenerationNode:
103+
model_config: ModelConfig | None = None,
104+
botorch_acqf_class: type[AcquisitionFunction] | None = None,
105+
) -> tuple[GenerationNode, str]:
100106
"""Constructs an MBM node based on the method specified in
101107
``struct``.
102108
103-
The ``SurrogateSpec`` takes the following form for the given method:
104-
- BALANCED: Two model configs: one with MBM defaults, the other with
105-
linear kernel with input warping.
106-
- FAST: An empty model config that utilizes MBM defaults.
109+
Args:
110+
method: The method to use for the MBM node. This can be one of
111+
- "quality": Uses Warped SAAS model.
112+
- "fast": Uses MBM defaults.
113+
- "custom": Uses the provided ``model_config``.
114+
torch_device: The torch device to use for the MBM node.
115+
simplify_parameter_changes: Whether to simplify parameter changes in
116+
the MBM node.
117+
model_config: Optional model config to use for the MBM node.
118+
This is only supported when ``method`` is "custom".
119+
botorch_acqf_class: An optional BoTorch ``AcquisitionFunction`` class
120+
to use for the MBM node.
107121
"""
108122
# Construct the surrogate spec.
109-
if method == "quality":
123+
if method == "custom":
124+
model_config = none_throws(model_config)
125+
model_configs = [model_config]
126+
mbm_name = (
127+
model_config.name if model_config.name is not None else "custom_config"
128+
)
129+
elif method == "quality":
110130
model_configs = [
111131
ModelConfig(
112132
botorch_model_class=SaasFullyBayesianSingleTaskGP,
@@ -117,36 +137,49 @@ def _get_mbm_node(
117137
name="WarpedSAAS",
118138
)
119139
]
140+
mbm_name = method
120141
elif method == "fast":
121142
model_configs = [ModelConfig(name="MBM defaults")]
143+
mbm_name = method
122144
else:
123145
raise UnsupportedError(f"Unsupported generation method: {method}.")
124146

147+
# Append acquisition function class name to the node name if provided.
148+
if botorch_acqf_class is not None:
149+
mbm_name = f"{mbm_name}+{botorch_acqf_class.__name__}"
150+
125151
device = None if torch_device is None else torch.device(torch_device)
126152

153+
# Construct generator kwargs.
154+
generator_kwargs: dict[str, Any] = {
155+
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
156+
"torch_device": device,
157+
"transform_configs": get_derelativize_config(
158+
derelativize_with_raw_status_quo=True
159+
),
160+
"acquisition_options": {
161+
"prune_irrelevant_parameters": simplify_parameter_changes
162+
},
163+
}
164+
if botorch_acqf_class is not None:
165+
generator_kwargs["botorch_acqf_class"] = botorch_acqf_class
166+
127167
return GenerationNode(
128168
name="MBM",
129169
generator_specs=[
130170
GeneratorSpec(
131171
generator_enum=Generators.BOTORCH_MODULAR,
132-
generator_kwargs={
133-
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
134-
"torch_device": device,
135-
"transform_configs": get_derelativize_config(
136-
derelativize_with_raw_status_quo=True
137-
),
138-
"acquisition_options": {
139-
"prune_irrelevant_parameters": simplify_parameter_changes
140-
},
141-
},
172+
generator_kwargs=generator_kwargs,
142173
)
143174
],
144175
should_deduplicate=True,
145-
)
176+
), mbm_name
146177

147178

148179
def choose_generation_strategy(
149180
struct: GenerationStrategyDispatchStruct,
181+
model_config: ModelConfig | None = None,
182+
botorch_acqf_class: type[AcquisitionFunction] | None = None,
150183
) -> GenerationStrategy:
151184
"""
152185
Choose a generation strategy based on the properties of the experiment and the
@@ -159,10 +192,26 @@ def choose_generation_strategy(
159192
struct: A ``GenerationStrategyDispatchStruct``
160193
object that informs
161194
the choice of generation strategy.
195+
model_config: An optional ``ModelConfig`` to use for the Bayesian optimization
196+
phase. This must be provided when ``struct.method`` is ``"custom"``, and
197+
must not be provided otherwise.
198+
botorch_acqf_class: An optional BoTorch ``AcquisitionFunction`` class to use
199+
for the Bayesian optimization phase. When provided, it will be passed as a
200+
model kwarg to the MBM node and its name will be appended to the node name.
162201
163202
Returns:
164203
A generation strategy.
165204
"""
205+
# Validate model_config usage.
206+
if struct.method == "custom":
207+
if model_config is None:
208+
raise UserInputError("model_config must be provided when method='custom'.")
209+
elif model_config is not None:
210+
raise UserInputError(
211+
"model_config should only be provided when method='custom'. "
212+
f"Got method='{struct.method}'."
213+
)
214+
166215
# Handle the random search case.
167216
if struct.method == "random_search":
168217
nodes = [
@@ -178,10 +227,12 @@ def choose_generation_strategy(
178227
]
179228
gs_name = "QuasiRandomSearch"
180229
else:
181-
mbm_node = _get_mbm_node(
230+
mbm_node, mbm_name = _get_mbm_node(
182231
method=struct.method,
183232
torch_device=struct.torch_device,
184233
simplify_parameter_changes=struct.simplify_parameter_changes,
234+
model_config=model_config,
235+
botorch_acqf_class=botorch_acqf_class,
185236
)
186237
if (
187238
struct.initialization_budget is None
@@ -198,10 +249,10 @@ def choose_generation_strategy(
198249
),
199250
mbm_node,
200251
]
201-
gs_name = f"Sobol+MBM:{struct.method}"
252+
gs_name = f"Sobol+MBM:{mbm_name}"
202253
else:
203254
nodes = [mbm_node]
204-
gs_name = f"MBM:{struct.method}"
255+
gs_name = f"MBM:{mbm_name}"
205256
if struct.initialize_with_center and (
206257
struct.initialization_budget is None or struct.initialization_budget > 0
207258
):

ax/api/utils/structs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ class GenerationStrategyDispatchStruct:
5353
- ``"random_search"``, primarily intended for pure exploration
5454
experiments, this method utilizes quasi-random Sobol sequences
5555
for candidate generation.
56+
- ``"custom"``, allows using a custom ``ModelConfig`` for the
57+
Bayesian optimization phase. When using this method, the
58+
``model_config`` argument must be provided to
59+
``choose_generation_strategy``. This is an advanced option
60+
and should not be considered a part of the public API.
5661
initialization_budget: The number of trials to use for initialization.
5762
If ``None``, a default budget of 5 trials is used.
5863
initialization_random_seed: The random seed to use with the Sobol generator
@@ -87,7 +92,7 @@ class GenerationStrategyDispatchStruct:
8792
irrelevant parameters.
8893
"""
8994

90-
method: Literal["quality", "fast", "random_search"] = "fast"
95+
method: Literal["quality", "fast", "random_search", "custom"] = "fast"
9196
# Initialization options
9297
initialization_budget: int | None = None
9398
initialization_random_seed: int | None = None

ax/api/utils/tests/test_generation_strategy_dispatch.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ax.api.utils.structs import GenerationStrategyDispatchStruct
1717
from ax.core.trial import Trial
1818
from ax.core.trial_status import TrialStatus
19+
from ax.exceptions.core import UserInputError
1920
from ax.generation_strategy.center_generation_node import CenterGenerationNode
2021
from ax.generation_strategy.dispatch_utils import get_derelativize_config
2122
from ax.generation_strategy.transition_criterion import MinTrials
@@ -27,7 +28,9 @@
2728
)
2829
from ax.utils.testing.mock import mock_botorch_optimize
2930
from ax.utils.testing.utils import run_trials_with_gs
31+
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
3032
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
33+
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP
3134
from pyre_extensions import assert_is_instance, none_throws
3235

3336

@@ -252,3 +255,118 @@ def test_gs_simplify_parameter_changes(self) -> None:
252255
mbm_spec.generator_kwargs["acquisition_options"],
253256
{"prune_irrelevant_parameters": simplify},
254257
)
258+
259+
def test_choose_gs_custom_with_model_config(self) -> None:
260+
"""Test that custom method works with a provided ModelConfig."""
261+
custom_model_config = ModelConfig(
262+
botorch_model_class=EnsembleMapSaasSingleTaskGP,
263+
name="MAPSAAS",
264+
)
265+
struct = GenerationStrategyDispatchStruct(
266+
method="custom",
267+
initialization_budget=3,
268+
initialize_with_center=False,
269+
torch_device="cpu",
270+
)
271+
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
272+
self.assertEqual(len(gs._nodes), 2)
273+
self.assertEqual(gs.name, "Sobol+MBM:MAPSAAS")
274+
275+
# Check the MBM node uses the custom model config.
276+
mbm_node = gs._nodes[1]
277+
self.assertEqual(len(mbm_node.generator_specs), 1)
278+
mbm_spec = mbm_node.generator_specs[0]
279+
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
280+
expected_ss = SurrogateSpec(model_configs=[custom_model_config])
281+
self.assertEqual(
282+
mbm_spec.generator_kwargs["surrogate_spec"],
283+
expected_ss,
284+
)
285+
self.assertEqual(
286+
mbm_spec.generator_kwargs["torch_device"],
287+
torch.device("cpu"),
288+
)
289+
290+
def test_choose_gs_custom_without_name(self) -> None:
291+
"""Test that custom method works with unnamed ModelConfig."""
292+
custom_model_config = ModelConfig(
293+
botorch_model_class=SaasFullyBayesianSingleTaskGP,
294+
# No name provided.
295+
)
296+
struct = GenerationStrategyDispatchStruct(
297+
method="custom",
298+
initialization_budget=3,
299+
initialize_with_center=False,
300+
)
301+
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
302+
# Should use "custom_config" as the default name.
303+
self.assertEqual(gs.name, "Sobol+MBM:custom_config")
304+
305+
def test_choose_gs_custom_model_config_validation(self) -> None:
306+
"""Test validation of model_config and custom method pairing."""
307+
# Test that custom method raises an error when model_config is not provided.
308+
struct = GenerationStrategyDispatchStruct(method="custom")
309+
with self.assertRaisesRegex(
310+
UserInputError,
311+
"model_config must be provided when method='custom'.",
312+
):
313+
choose_generation_strategy(struct=struct)
314+
315+
# Test that providing model_config without custom method raises an error.
316+
custom_model_config = ModelConfig(name="SomeConfig")
317+
struct = GenerationStrategyDispatchStruct(method="fast")
318+
with self.assertRaisesRegex(
319+
UserInputError,
320+
"model_config should only be provided when method='custom'. "
321+
"Got method='fast'.",
322+
):
323+
choose_generation_strategy(struct=struct, model_config=custom_model_config)
324+
325+
def test_choose_gs_with_custom_botorch_acqf_class(self) -> None:
326+
"""Test that custom botorch_acqf_class is properly passed to generator kwargs
327+
and appended to the node name. Tests both fast and custom methods.
328+
"""
329+
for method, model_config, expected_name in [
330+
("fast", None, "Sobol+MBM:fast+qLogNoisyExpectedImprovement"),
331+
(
332+
"custom",
333+
ModelConfig(
334+
botorch_model_class=EnsembleMapSaasSingleTaskGP,
335+
name="MAPSAAS",
336+
),
337+
"Sobol+MBM:MAPSAAS+qLogNoisyExpectedImprovement",
338+
),
339+
]:
340+
with self.subTest(method=method):
341+
struct = GenerationStrategyDispatchStruct(
342+
method=method, # pyre-ignore [6]
343+
initialization_budget=3,
344+
initialize_with_center=False,
345+
)
346+
gs = choose_generation_strategy(
347+
struct=struct,
348+
model_config=model_config,
349+
botorch_acqf_class=qLogNoisyExpectedImprovement,
350+
)
351+
# Check that the name includes the acquisition function class name.
352+
self.assertEqual(gs.name, expected_name)
353+
354+
# Check that MBM node generator kwargs include the botorch_acqf_class.
355+
mbm_node = gs._nodes[1]
356+
self.assertEqual(len(mbm_node.generator_specs), 1)
357+
mbm_spec = mbm_node.generator_specs[0]
358+
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
359+
self.assertEqual(
360+
mbm_spec.generator_kwargs["botorch_acqf_class"],
361+
qLogNoisyExpectedImprovement,
362+
)
363+
# Check surrogate spec uses expected model config.
364+
expected_model_config = (
365+
model_config
366+
if model_config is not None
367+
else ModelConfig(name="MBM defaults")
368+
)
369+
expected_ss = SurrogateSpec(model_configs=[expected_model_config])
370+
self.assertEqual(
371+
mbm_spec.generator_kwargs["surrogate_spec"], expected_ss
372+
)

0 commit comments

Comments
 (0)