Skip to content

Commit 91099b4

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add custom method with model_config to choose_gs
Summary: This diff adds the ability to use custom `ModelConfig` objects with `choose_generation_strategy` by introducing a new `"custom"` method option. ## Motivation After D83062555, the internal `choose_gs_internal` function now uses the OSS `choose_generation_strategy` for default GS dispatch. However, there was no way to pass custom model configurations (like MAP SAAS) through this function - they had to be handled as special cases outside the dispatch flow. ## Solution Rather than adding specific method types for each model variant (which would turn the Literal type into a "zoo" of options), we add a single `"custom"` method that allows passing a `ModelConfig` as a separate function argument. This keeps the API clean while providing full flexibility. Key changes: - Add `"custom"` to the `method` Literal type in `GenerationStrategyDispatchStruct` - Add optional `model_config` parameter to `choose_generation_strategy()` - Add validation: `model_config` must be provided iff `method="custom"` - Update `_get_mbm_node()` to handle custom model configs Differential Revision: D89906836
1 parent 94f78c7 commit 91099b4

File tree

3 files changed

+107
-8
lines changed

3 files changed

+107
-8
lines changed

ax/api/utils/generation_strategy_dispatch.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ax.adapter.registry import Generators
1212
from ax.api.utils.structs import GenerationStrategyDispatchStruct
1313
from ax.core.trial_status import TrialStatus
14-
from ax.exceptions.core import UnsupportedError
14+
from ax.exceptions.core import UnsupportedError, UserInputError
1515
from ax.generation_strategy.center_generation_node import CenterGenerationNode
1616
from ax.generation_strategy.dispatch_utils import get_derelativize_config
1717
from ax.generation_strategy.generation_strategy import (
@@ -22,6 +22,7 @@
2222
from ax.generation_strategy.transition_criterion import MinTrials
2323
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
2424
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
25+
from pyre_extensions import none_throws
2526

2627

2728
def _get_sobol_node(
@@ -96,17 +97,25 @@ def _get_mbm_node(
9697
method: str,
9798
torch_device: str | None,
9899
simplify_parameter_changes: bool,
99-
) -> GenerationNode:
100+
model_config: ModelConfig | None = None,
101+
) -> tuple[GenerationNode, str]:
100102
"""Constructs an MBM node based on the method specified in
101103
``struct``.
102104
103105
The ``SurrogateSpec`` takes the following form for the given method:
104106
- BALANCED: Two model configs: one with MBM defaults, the other with
105107
linear kernel with input warping.
106108
- FAST: An empty model config that utilizes MBM defaults.
109+
- CUSTOM: Uses the provided ``model_config``.
107110
"""
108111
# Construct the surrogate spec.
109-
if method == "quality":
112+
if method == "custom":
113+
model_config = none_throws(model_config)
114+
model_configs = [model_config]
115+
mbm_name = (
116+
model_config.name if model_config.name is not None else "custom_config"
117+
)
118+
elif method == "quality":
110119
model_configs = [
111120
ModelConfig(
112121
botorch_model_class=SaasFullyBayesianSingleTaskGP,
@@ -117,8 +126,10 @@ def _get_mbm_node(
117126
name="WarpedSAAS",
118127
)
119128
]
129+
mbm_name = method
120130
elif method == "fast":
121131
model_configs = [ModelConfig(name="MBM defaults")]
132+
mbm_name = method
122133
else:
123134
raise UnsupportedError(f"Unsupported generation method: {method}.")
124135

@@ -142,11 +153,12 @@ def _get_mbm_node(
142153
)
143154
],
144155
should_deduplicate=True,
145-
)
156+
), mbm_name
146157

147158

148159
def choose_generation_strategy(
149160
struct: GenerationStrategyDispatchStruct,
161+
model_config: ModelConfig | None = None,
150162
) -> GenerationStrategy:
151163
"""
152164
Choose a generation strategy based on the properties of the experiment and the
@@ -159,10 +171,23 @@ def choose_generation_strategy(
159171
struct: A ``GenerationStrategyDispatchStruct``
160172
object that informs
161173
the choice of generation strategy.
174+
model_config: An optional ``ModelConfig`` to use for the Bayesian optimization
175+
phase. This must be provided when ``struct.method`` is ``"custom"``, and
176+
must not be provided otherwise.
162177
163178
Returns:
164179
A generation strategy.
165180
"""
181+
# Validate model_config usage.
182+
if struct.method == "custom":
183+
if model_config is None:
184+
raise UserInputError("model_config must be provided when method='custom'.")
185+
elif model_config is not None:
186+
raise UserInputError(
187+
"model_config should only be provided when method='custom'. "
188+
f"Got method='{struct.method}'."
189+
)
190+
166191
# Handle the random search case.
167192
if struct.method == "random_search":
168193
nodes = [
@@ -178,10 +203,11 @@ def choose_generation_strategy(
178203
]
179204
gs_name = "QuasiRandomSearch"
180205
else:
181-
mbm_node = _get_mbm_node(
206+
mbm_node, mbm_name = _get_mbm_node(
182207
method=struct.method,
183208
torch_device=struct.torch_device,
184209
simplify_parameter_changes=struct.simplify_parameter_changes,
210+
model_config=model_config,
185211
)
186212
if (
187213
struct.initialization_budget is None
@@ -198,10 +224,10 @@ def choose_generation_strategy(
198224
),
199225
mbm_node,
200226
]
201-
gs_name = f"Sobol+MBM:{struct.method}"
227+
gs_name = f"Sobol+MBM:{mbm_name}"
202228
else:
203229
nodes = [mbm_node]
204-
gs_name = f"MBM:{struct.method}"
230+
gs_name = f"MBM:{mbm_name}"
205231
if struct.initialize_with_center and (
206232
struct.initialization_budget is None or struct.initialization_budget > 0
207233
):

ax/api/utils/structs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ class GenerationStrategyDispatchStruct:
5353
- ``"random_search"``, primarily intended for pure exploration
5454
experiments, this method utilizes quasi-random Sobol sequences
5555
for candidate generation.
56+
- ``"custom"``, allows using a custom ``ModelConfig`` for the
57+
Bayesian optimization phase. When using this method, the
58+
``model_config`` argument must be provided to
59+
``choose_generation_strategy``. This is an advanced option
60+
and should not be considered a part of the public API.
5661
initialization_budget: The number of trials to use for initialization.
5762
If ``None``, a default budget of 5 trials is used.
5863
initialization_random_seed: The random seed to use with the Sobol generator
@@ -87,7 +92,7 @@ class GenerationStrategyDispatchStruct:
8792
irrelevant parameters.
8893
"""
8994

90-
method: Literal["quality", "fast", "random_search"] = "fast"
95+
method: Literal["quality", "fast", "random_search", "custom"] = "fast"
9196
# Initialization options
9297
initialization_budget: int | None = None
9398
initialization_random_seed: int | None = None

ax/api/utils/tests/test_generation_strategy_dispatch.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ax.api.utils.structs import GenerationStrategyDispatchStruct
1717
from ax.core.trial import Trial
1818
from ax.core.trial_status import TrialStatus
19+
from ax.exceptions.core import UserInputError
1920
from ax.generation_strategy.center_generation_node import CenterGenerationNode
2021
from ax.generation_strategy.dispatch_utils import get_derelativize_config
2122
from ax.generation_strategy.transition_criterion import MinTrials
@@ -28,6 +29,7 @@
2829
from ax.utils.testing.mock import mock_botorch_optimize
2930
from ax.utils.testing.utils import run_trials_with_gs
3031
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
32+
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP
3133
from pyre_extensions import assert_is_instance, none_throws
3234

3335

@@ -252,3 +254,69 @@ def test_gs_simplify_parameter_changes(self) -> None:
252254
mbm_spec.generator_kwargs["acquisition_options"],
253255
{"prune_irrelevant_parameters": simplify},
254256
)
257+
258+
def test_choose_gs_custom_with_model_config(self) -> None:
259+
"""Test that custom method works with a provided ModelConfig."""
260+
custom_model_config = ModelConfig(
261+
botorch_model_class=EnsembleMapSaasSingleTaskGP,
262+
name="MAPSAAS",
263+
)
264+
struct = GenerationStrategyDispatchStruct(
265+
method="custom",
266+
initialization_budget=3,
267+
initialize_with_center=False,
268+
torch_device="cpu",
269+
)
270+
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
271+
self.assertEqual(len(gs._nodes), 2)
272+
self.assertEqual(gs.name, "Sobol+MBM:MAPSAAS")
273+
274+
# Check the MBM node uses the custom model config.
275+
mbm_node = gs._nodes[1]
276+
self.assertEqual(len(mbm_node.generator_specs), 1)
277+
mbm_spec = mbm_node.generator_specs[0]
278+
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
279+
expected_ss = SurrogateSpec(model_configs=[custom_model_config])
280+
self.assertEqual(
281+
mbm_spec.generator_kwargs["surrogate_spec"],
282+
expected_ss,
283+
)
284+
self.assertEqual(
285+
mbm_spec.generator_kwargs["torch_device"],
286+
torch.device("cpu"),
287+
)
288+
289+
def test_choose_gs_custom_without_name(self) -> None:
290+
"""Test that custom method works with unnamed ModelConfig."""
291+
custom_model_config = ModelConfig(
292+
botorch_model_class=SaasFullyBayesianSingleTaskGP,
293+
# No name provided.
294+
)
295+
struct = GenerationStrategyDispatchStruct(
296+
method="custom",
297+
initialization_budget=3,
298+
initialize_with_center=False,
299+
)
300+
gs = choose_generation_strategy(struct=struct, model_config=custom_model_config)
301+
# Should use "custom_config" as the default name.
302+
self.assertEqual(gs.name, "Sobol+MBM:custom_config")
303+
304+
def test_choose_gs_custom_model_config_validation(self) -> None:
305+
"""Test validation of model_config and custom method pairing."""
306+
# Test that custom method raises an error when model_config is not provided.
307+
struct = GenerationStrategyDispatchStruct(method="custom")
308+
with self.assertRaisesRegex(
309+
UserInputError,
310+
"model_config must be provided when method='custom'.",
311+
):
312+
choose_generation_strategy(struct=struct)
313+
314+
# Test that providing model_config without custom method raises an error.
315+
custom_model_config = ModelConfig(name="SomeConfig")
316+
struct = GenerationStrategyDispatchStruct(method="fast")
317+
with self.assertRaisesRegex(
318+
UserInputError,
319+
"model_config should only be provided when method='custom'. "
320+
"Got method='fast'.",
321+
):
322+
choose_generation_strategy(struct=struct, model_config=custom_model_config)

0 commit comments

Comments
 (0)