Skip to content

Commit f973cb4

Browse files
andycylmetameta-codesync[bot]
authored andcommitted
Search space editing (#4945)
Summary: Pull Request resolved: #4945 This diff adds methods for search space manipulation for AxClient for convenience: 1. add_parameters: add parameters into search spaces 2. update_parameters: update the bounds for existing range parameters in the search space. 3. disable_parameters: Disable parameters in the experiment. This allows narrowing the search space after the experiment has run some trials. Reviewed By: mpolson64 Differential Revision: D93766951 fbshipit-source-id: 1d3cc7bf13668be3727367c8587dd8fcd607d68e
1 parent 00960e9 commit f973cb4

File tree

2 files changed

+301
-0
lines changed

2 files changed

+301
-0
lines changed

ax/service/ax_client.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import pandas as pd
2020
import torch
2121
from ax.adapter.prediction_utils import predict_by_features
22+
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig
23+
from ax.api.utils.instantiation.from_config import parameter_from_config
2224
from ax.core.arm import Arm
2325
from ax.core.base_trial import BaseTrial
2426
from ax.core.evaluations_to_data import raw_evaluations_to_data
@@ -27,6 +29,7 @@
2729
from ax.core.multi_type_experiment import MultiTypeExperiment
2830
from ax.core.objective import MultiObjective, Objective
2931
from ax.core.observation import ObservationFeatures
32+
from ax.core.parameter import RangeParameter
3033
from ax.core.runner import Runner
3134
from ax.core.trial import Trial
3235
from ax.core.trial_status import TrialStatus
@@ -516,6 +519,111 @@ def set_search_space(
516519
experiment=self.experiment,
517520
)
518521

522+
def add_parameters(
523+
self,
524+
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
525+
backfill_values: TParameterization,
526+
status_quo_values: TParameterization | None = None,
527+
) -> None:
528+
"""
529+
Add new parameters to the experiment's search space. This allows extending
530+
the search space after the experiment has run some trials.
531+
532+
Backfill values must be provided for all new parameters to ensure existing
533+
trials in the experiment remain valid within the expanded search space. The
534+
backfill values represent the parameter values that were used in the existing
535+
trials.
536+
537+
Args:
538+
parameters: A sequence of parameter configurations to add to the search
539+
space.
540+
backfill_values: Parameter values to assign to existing trials for the
541+
new parameters being added. All new parameter names must have
542+
corresponding backfill values provided.
543+
status_quo_values: Optional parameter values for the new parameters to
544+
use in the status quo (baseline) arm, if one is defined. If None,
545+
the backfill values will be used for the status quo.
546+
"""
547+
parameters_to_add = [
548+
parameter_from_config(parameter_config) for parameter_config in parameters
549+
]
550+
parameter_names = {parameter.name for parameter in parameters_to_add}
551+
missing_backfill_values = parameter_names - backfill_values.keys()
552+
if missing_backfill_values:
553+
raise UserInputError(
554+
"You must provide backfill values for all parameters being added. "
555+
f"Missing values for parameters: {missing_backfill_values}."
556+
)
557+
extra_backfill_values = backfill_values.keys() - parameter_names
558+
if extra_backfill_values:
559+
logger.warning(
560+
"Backfill values provided for parameters not being added: "
561+
f"{extra_backfill_values}. Will ignore these values."
562+
)
563+
for parameter in parameters_to_add:
564+
if parameter.name in backfill_values:
565+
parameter._backfill_value = backfill_values[parameter.name]
566+
self.experiment.add_parameters_to_search_space(
567+
parameters=parameters_to_add,
568+
status_quo_values=status_quo_values,
569+
)
570+
self._save_experiment_to_db_if_possible(experiment=self.experiment)
571+
572+
def disable_parameters(
573+
self,
574+
default_parameter_values: TParameterization,
575+
) -> None:
576+
"""
577+
Disable parameters in the experiment. This allows narrowing the search space
578+
after the experiment has run some trials.
579+
580+
When parameters are disabled, they are effectively removed from the search
581+
space for future trial generation. Existing trials remain valid, and the
582+
disabled parameters are replaced with fixed default values for all subsequent
583+
trials.
584+
585+
Args:
586+
default_parameter_values: Fixed values to use for the disabled parameters
587+
in all future trials. These values will be used for the parameter in
588+
all subsequent trials.
589+
"""
590+
self.experiment.disable_parameters_in_search_space(
591+
default_parameter_values=default_parameter_values
592+
)
593+
self._save_experiment_to_db_if_possible(experiment=self.experiment)
594+
595+
def update_parameters(
596+
self,
597+
parameters: Sequence[RangeParameterConfig],
598+
) -> None:
599+
"""Update parameters in the experiment's search space.
600+
601+
This allows modifying the search space after the experiment has run some
602+
trials.
603+
604+
Args:
605+
parameters: A sequence of ``RangeParameterConfig`` to update in the
606+
search space.
607+
608+
Raises:
609+
UserInputError: If a parameter is not found in the search space or
610+
if the parameter is not a ``RangeParameter``.
611+
"""
612+
search_space = self.experiment.search_space
613+
for parameter in parameters:
614+
if parameter.name not in search_space.parameters:
615+
raise UserInputError(
616+
f"Parameter {parameter.name} not found in search space."
617+
)
618+
if not isinstance(search_space.parameters[parameter.name], RangeParameter):
619+
raise UserInputError(
620+
f"Parameter {parameter.name} is not a RangeParameter."
621+
)
622+
623+
for parameter in parameters:
624+
search_space.update_parameter(parameter=parameter_from_config(parameter))
625+
self._save_experiment_to_db_if_possible(experiment=self.experiment)
626+
519627
@retry_on_exception(
520628
logger=logger,
521629
exception_types=(RuntimeError,),

ax/service/tests/test_ax_client.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import torch
2121
from ax.adapter.registry import Cont_X_trans, Generators
22+
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig
2223
from ax.core.arm import Arm
2324
from ax.core.data import Data, MAP_KEY
2425
from ax.core.generator_run import GeneratorRun
@@ -1355,6 +1356,198 @@ def test_set_search_space(self) -> None:
13551356
[ParameterConstraint(inequality="x1 <= x2")],
13561357
)
13571358

1359+
def test_update_parameters(self) -> None:
1360+
"""Test that update_parameters correctly updates parameters and raises
1361+
appropriate errors."""
1362+
ax_client = AxClient()
1363+
ax_client.create_experiment(
1364+
name="test_experiment",
1365+
parameters=[
1366+
{
1367+
"name": "x1",
1368+
"type": "range",
1369+
"bounds": [0.0, 1.0],
1370+
"value_type": "float",
1371+
},
1372+
{
1373+
"name": "x2",
1374+
"type": "range",
1375+
"bounds": [1, 10],
1376+
"value_type": "int",
1377+
},
1378+
{
1379+
"name": "x3",
1380+
"type": "choice",
1381+
"values": ["a", "b", "c"],
1382+
},
1383+
],
1384+
is_test=True,
1385+
immutable_search_space_and_opt_config=False,
1386+
)
1387+
1388+
# --- sub-test 1: update RangeParameter bounds (float) ---
1389+
with self.subTest("update_float_range_parameter"):
1390+
ax_client.update_parameters(
1391+
parameters=[
1392+
RangeParameterConfig(
1393+
name="x1",
1394+
bounds=(0.5, 2.0),
1395+
parameter_type="float",
1396+
),
1397+
]
1398+
)
1399+
param = ax_client.experiment.search_space.parameters["x1"]
1400+
self.assertIsInstance(param, RangeParameter)
1401+
assert isinstance(param, RangeParameter)
1402+
self.assertEqual(param.lower, 0.5)
1403+
self.assertEqual(param.upper, 2.0)
1404+
1405+
# --- sub-test 2: update RangeParameter bounds (int) ---
1406+
with self.subTest("update_int_range_parameter"):
1407+
ax_client.update_parameters(
1408+
parameters=[
1409+
RangeParameterConfig(
1410+
name="x2",
1411+
bounds=(5, 20),
1412+
parameter_type="int",
1413+
),
1414+
]
1415+
)
1416+
param = ax_client.experiment.search_space.parameters["x2"]
1417+
self.assertIsInstance(param, RangeParameter)
1418+
assert isinstance(param, RangeParameter)
1419+
self.assertEqual(param.lower, 5)
1420+
self.assertEqual(param.upper, 20)
1421+
1422+
# --- sub-test 3: raises on missing parameter ---
1423+
with self.subTest("raises_on_missing_parameter"):
1424+
with self.assertRaisesRegex(
1425+
UserInputError, "Parameter nonexistent not found in search space"
1426+
):
1427+
ax_client.update_parameters(
1428+
parameters=[
1429+
RangeParameterConfig(
1430+
name="nonexistent",
1431+
bounds=(0.0, 1.0),
1432+
parameter_type="float",
1433+
),
1434+
]
1435+
)
1436+
1437+
# --- sub-test 4: raises on non-RangeParameter ---
1438+
with self.subTest("raises_on_non_range_parameter"):
1439+
with self.assertRaisesRegex(
1440+
UserInputError, "Parameter x3 is not a RangeParameter"
1441+
):
1442+
ax_client.update_parameters(
1443+
parameters=[
1444+
RangeParameterConfig(
1445+
name="x3",
1446+
bounds=(0.0, 1.0),
1447+
parameter_type="float",
1448+
),
1449+
]
1450+
)
1451+
1452+
def test_add_parameters(self) -> None:
1453+
"""Test that add_parameters correctly adds new parameters to the
1454+
search space.
1455+
"""
1456+
ax_client = AxClient()
1457+
ax_client.create_experiment(
1458+
name="test_experiment",
1459+
parameters=[
1460+
{
1461+
"name": "x1",
1462+
"type": "range",
1463+
"bounds": [0.0, 1.0],
1464+
"value_type": "float",
1465+
},
1466+
],
1467+
is_test=True,
1468+
immutable_search_space_and_opt_config=False,
1469+
)
1470+
1471+
ax_client.add_parameters(
1472+
parameters=[
1473+
RangeParameterConfig(
1474+
name="x2",
1475+
bounds=(0.0, 10.0),
1476+
parameter_type="float",
1477+
),
1478+
ChoiceParameterConfig(
1479+
name="x3",
1480+
values=["a", "b", "c"],
1481+
parameter_type="str",
1482+
),
1483+
],
1484+
backfill_values={"x2": 5.0, "x3": "a"},
1485+
)
1486+
1487+
search_space = ax_client.experiment.search_space
1488+
self.assertIn("x1", search_space.parameters)
1489+
self.assertIn("x2", search_space.parameters)
1490+
self.assertIn("x3", search_space.parameters)
1491+
1492+
param_x2 = search_space.parameters["x2"]
1493+
self.assertIsInstance(param_x2, RangeParameter)
1494+
assert isinstance(param_x2, RangeParameter)
1495+
self.assertEqual(param_x2.lower, 0.0)
1496+
self.assertEqual(param_x2.upper, 10.0)
1497+
1498+
param_x3 = search_space.parameters["x3"]
1499+
self.assertIsInstance(param_x3, ChoiceParameter)
1500+
assert isinstance(param_x3, ChoiceParameter)
1501+
self.assertEqual(param_x3.values, ["a", "b", "c"])
1502+
1503+
def test_disable_parameters(self) -> None:
1504+
"""Test that disable_parameters correctly disables parameters in the search
1505+
space."""
1506+
ax_client = AxClient()
1507+
ax_client.create_experiment(
1508+
name="test_experiment",
1509+
parameters=[
1510+
{
1511+
"name": "x1",
1512+
"type": "range",
1513+
"bounds": [0.0, 1.0],
1514+
"value_type": "float",
1515+
},
1516+
{
1517+
"name": "x2",
1518+
"type": "range",
1519+
"bounds": [1, 10],
1520+
"value_type": "int",
1521+
},
1522+
{
1523+
"name": "x3",
1524+
"type": "choice",
1525+
"values": ["a", "b", "c"],
1526+
},
1527+
],
1528+
is_test=True,
1529+
immutable_search_space_and_opt_config=False,
1530+
)
1531+
1532+
ax_client.disable_parameters(default_parameter_values={"x2": 5, "x3": "b"})
1533+
1534+
search_space = ax_client.experiment.search_space
1535+
self.assertIn("x1", search_space.parameters)
1536+
self.assertIn("x2", search_space.parameters)
1537+
self.assertIn("x3", search_space.parameters)
1538+
1539+
param_x1 = search_space.parameters["x1"]
1540+
self.assertIsInstance(param_x1, RangeParameter)
1541+
self.assertFalse(param_x1.is_disabled)
1542+
1543+
param_x2 = search_space.parameters["x2"]
1544+
self.assertTrue(param_x2.is_disabled)
1545+
self.assertEqual(param_x2.default_value, 5)
1546+
1547+
param_x3 = search_space.parameters["x3"]
1548+
self.assertTrue(param_x3.is_disabled)
1549+
self.assertEqual(param_x3.default_value, "b")
1550+
13581551
def test_create_moo_experiment(self) -> None:
13591552
"""Test basic experiment creation."""
13601553
ax_client = AxClient(

0 commit comments

Comments
 (0)