Skip to content

Commit 68d37c1

Browse files
Lena Kashtelyanmeta-codesync[bot]
authored andcommitted
Deal with total_concurrency and n weirdness: introduce ExperimentDesign.concurrency_limit (#4732)
Summary: Pull Request resolved: #4732 As titled, adding a simple `ExperimentDesign` object. Putting it into properties for serialization for now, so as to not do duplicate work ahead of the storage refactor implementation (and also in case we change things while working on this stack). Differential Revision: D89770462 Privacy Context Container: L1307644
1 parent 17b25cd commit 68d37c1

File tree

6 files changed

+237
-1
lines changed

6 files changed

+237
-1
lines changed

ax/core/experiment.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import warnings
1515
from collections import defaultdict, OrderedDict
1616
from collections.abc import Hashable, Iterable, Mapping, Sequence
17+
from dataclasses import dataclass
1718
from datetime import datetime
1819
from functools import partial, reduce
1920
from typing import Any, cast, Union
@@ -84,6 +85,29 @@
8485
}
8586

8687

88+
@dataclass
89+
class ExperimentDesign:
90+
"""Struct that holds "experiment design" configuration: these are
91+
experiment-level settings that pertain to "how the experiment will be
92+
run or conducted", but are agnostic to the specific evaluation
93+
backend, to which the trials will be deployed.
94+
95+
NOTE: In the future, we might treat concurrency limit as expressed
96+
in terms of "full arm equivalents" as opposed to just "number of arms",
97+
to cover for the multi-fidelity cases.
98+
99+
Args:
100+
concurrency_limit: Maximum number of arms to run within one or
101+
multiple trials, in parallel. In experiments that consist of
102+
`Trial`-s, this is equivalent to the total number of trials
103+
that should run in parallel. In experiments with `BatchTrial`-s,
104+
this total number of arms can be spread across one or
105+
multiple `BatchTrial`-s.
106+
"""
107+
108+
concurrency_limit: int | None = None
109+
110+
87111
class Experiment(Base):
88112
"""Base class for defining an experiment."""
89113

@@ -150,6 +174,12 @@ def __init__(
150174
self._time_created: datetime = datetime.now()
151175
self._trials: dict[int, BaseTrial] = {}
152176
self._properties: dict[str, Any] = properties or {}
177+
self._design: ExperimentDesign = ExperimentDesign()
178+
# Restore ExperimentDesign from properties if present (for deserialization).
179+
# TODO[drfreund, mpolson64]: Replace with proper storage as part of the
180+
# refactor.
181+
if (design_dict := self._properties.pop("design", None)) is not None:
182+
self._design.concurrency_limit = design_dict.get("concurrency_limit")
153183

154184
# Initialize trial type to runner mapping
155185
self._default_trial_type = default_trial_type
@@ -233,6 +263,11 @@ def experiment_type(self, experiment_type: str | None) -> None:
233263
"""Set the type of the experiment."""
234264
self._experiment_type = experiment_type
235265

266+
@property
267+
def design(self) -> ExperimentDesign:
268+
"""The experiment design configuration."""
269+
return self._design
270+
236271
@property
237272
def search_space(self) -> SearchSpace:
238273
"""The search space for this experiment.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from ax.core import Experiment
10+
from ax.core.experiment import ExperimentDesign
11+
from ax.utils.common.testutils import TestCase
12+
from ax.utils.testing.core_stubs import get_branin_search_space
13+
14+
15+
class ExperimentDesignTest(TestCase):
16+
"""Tests for ExperimentDesign and related logic."""
17+
18+
def test_experiment_design_defaults(self) -> None:
19+
"""Test that ExperimentDesign has expected defaults."""
20+
design = ExperimentDesign()
21+
self.assertIsNone(design.concurrency_limit)
22+
23+
def test_experiment_design_with_concurrency_limit(self) -> None:
24+
"""Test ExperimentDesign with concurrency_limit set."""
25+
design = ExperimentDesign(concurrency_limit=10)
26+
self.assertEqual(design.concurrency_limit, 10)
27+
28+
design_zero = ExperimentDesign(concurrency_limit=0)
29+
self.assertEqual(design_zero.concurrency_limit, 0)
30+
31+
def test_experiment_design_property(self) -> None:
32+
"""Test that Experiment.design property returns ExperimentDesign instance."""
33+
experiment = Experiment(
34+
name="test",
35+
search_space=get_branin_search_space(),
36+
)
37+
self.assertIsInstance(experiment.design, ExperimentDesign)
38+
self.assertIsNone(experiment.design.concurrency_limit)
39+
40+
def test_experiment_design_modification(self) -> None:
41+
"""Test that ExperimentDesign can be modified after experiment creation."""
42+
experiment = Experiment(
43+
name="test",
44+
search_space=get_branin_search_space(),
45+
)
46+
# Modify concurrency_limit
47+
experiment.design.concurrency_limit = 5
48+
self.assertEqual(experiment.design.concurrency_limit, 5)
49+
50+
# Set to None
51+
experiment.design.concurrency_limit = None
52+
self.assertIsNone(experiment.design.concurrency_limit)
53+
54+
def test_experiment_design_restore_from_properties(self) -> None:
55+
"""Test that ExperimentDesign is restored from properties during init."""
56+
# Simulate deserializing an experiment with design stored in properties
57+
experiment = Experiment(
58+
name="test",
59+
search_space=get_branin_search_space(),
60+
properties={"design": {"concurrency_limit": 25}},
61+
)
62+
# The design should be restored and the "design" key should be removed
63+
# from properties
64+
self.assertEqual(experiment.design.concurrency_limit, 25)
65+
self.assertNotIn("design", experiment._properties)
66+
67+
def test_experiment_design_restore_from_properties_with_none(self) -> None:
68+
"""Test that ExperimentDesign handles None concurrency_limit in properties."""
69+
experiment = Experiment(
70+
name="test",
71+
search_space=get_branin_search_space(),
72+
properties={"design": {"concurrency_limit": None}},
73+
)
74+
self.assertIsNone(experiment.design.concurrency_limit)
75+
# The "design" key should be removed from properties once it's consumed
76+
# to recreate `ExperimentDesign`.
77+
self.assertNotIn("design", experiment._properties)
78+
79+
def test_experiment_design_restore_from_properties_empty_dict(self) -> None:
80+
"""Test that ExperimentDesign handles empty design dict in properties."""
81+
experiment = Experiment(
82+
name="test",
83+
search_space=get_branin_search_space(),
84+
properties={"design": {}},
85+
)
86+
self.assertIsNone(experiment.design.concurrency_limit)
87+
# The "design" key should be removed from properties once it's consumed
88+
# to recreate `ExperimentDesign`.
89+
self.assertNotIn("design", experiment._properties)
90+
91+
def test_experiment_design_not_affecting_other_properties(self) -> None:
92+
"""Test that ExperimentDesign restoration doesn't affect other properties."""
93+
experiment = Experiment(
94+
name="test",
95+
search_space=get_branin_search_space(),
96+
properties={
97+
"design": {"concurrency_limit": 15},
98+
"custom_property": "custom_value",
99+
"another_property": 42,
100+
},
101+
)
102+
self.assertEqual(experiment.design.concurrency_limit, 15)
103+
# The "design" key should be removed from properties once it's consumed
104+
# to recreate `ExperimentDesign`.
105+
self.assertNotIn("design", experiment._properties)
106+
self.assertEqual(experiment._properties["custom_property"], "custom_value")
107+
self.assertEqual(experiment._properties["another_property"], 42)

ax/storage/json_store/encoders.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@
7373

7474
def experiment_to_dict(experiment: Experiment) -> dict[str, Any]:
7575
"""Convert Ax experiment to a dictionary."""
76+
# Serialize ExperimentDesign into properties
77+
properties = {
78+
**experiment._properties,
79+
"design": {
80+
"concurrency_limit": experiment.design.concurrency_limit,
81+
},
82+
}
7683
return {
7784
"__type": experiment.__class__.__name__,
7885
"name": experiment._name,
@@ -87,7 +94,7 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]:
8794
"trials": experiment.trials,
8895
"is_test": experiment.is_test,
8996
"data_by_trial": experiment._data_by_trial,
90-
"properties": experiment._properties,
97+
"properties": properties,
9198
"_trial_type_to_runner": experiment._trial_type_to_runner,
9299
}
93100

ax/storage/json_store/tests/test_json_store.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,60 @@ def test_experiment_with_pruning_target_json_roundtrip(self) -> None:
16561656
).pruning_target_parameterization,
16571657
)
16581658

1659+
def test_experiment_design_json_roundtrip(self) -> None:
1660+
"""Test that ExperimentDesign is preserved through JSON serialization."""
1661+
# Setup: create experiment and set concurrency_limit
1662+
experiment = get_branin_experiment()
1663+
experiment.design.concurrency_limit = 42
1664+
1665+
# Execute: save and load experiment through JSON
1666+
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
1667+
save_experiment(
1668+
experiment,
1669+
f.name,
1670+
encoder_registry=CORE_ENCODER_REGISTRY,
1671+
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
1672+
)
1673+
loaded_experiment = load_experiment(
1674+
f.name,
1675+
decoder_registry=CORE_DECODER_REGISTRY,
1676+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
1677+
)
1678+
1679+
# Cleanup
1680+
os.remove(f.name)
1681+
1682+
# Assert: confirm ExperimentDesign is preserved
1683+
self.assertEqual(experiment, loaded_experiment)
1684+
self.assertEqual(loaded_experiment.design.concurrency_limit, 42)
1685+
1686+
def test_experiment_design_none_concurrency_json_roundtrip(self) -> None:
1687+
"""Test that ExperimentDesign with None concurrency_limit is preserved."""
1688+
# Setup: create experiment with default (None) concurrency_limit
1689+
experiment = get_branin_experiment()
1690+
self.assertIsNone(experiment.design.concurrency_limit)
1691+
1692+
# Execute: save and load experiment through JSON
1693+
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
1694+
save_experiment(
1695+
experiment,
1696+
f.name,
1697+
encoder_registry=CORE_ENCODER_REGISTRY,
1698+
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
1699+
)
1700+
loaded_experiment = load_experiment(
1701+
f.name,
1702+
decoder_registry=CORE_DECODER_REGISTRY,
1703+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
1704+
)
1705+
1706+
# Cleanup
1707+
os.remove(f.name)
1708+
1709+
# Assert: confirm ExperimentDesign is preserved with None
1710+
self.assertEqual(experiment, loaded_experiment)
1711+
self.assertIsNone(loaded_experiment.design.concurrency_limit)
1712+
16591713
def test_multi_objective_from_json_warning(self) -> None:
16601714
objectives = [get_objective()]
16611715

ax/storage/sqa_store/encoder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
227227
elif experiment.runner:
228228
runners.append(self.runner_to_sqa(none_throws(experiment.runner)))
229229
properties = experiment._properties.copy()
230+
# Serialize ExperimentDesign into properties
231+
properties["design"] = {
232+
"concurrency_limit": experiment.design.concurrency_limit,
233+
}
230234
if (
231235
oc := experiment.optimization_config
232236
) is not None and oc.pruning_target_parameterization is not None:

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,35 @@ def test_experiment_save_load(self) -> None:
339339
loaded_experiment = load_experiment(exp.name)
340340
self.assertEqual(loaded_experiment, exp)
341341

342+
def test_experiment_design_sqa_roundtrip(self) -> None:
343+
"""Test that ExperimentDesign is preserved through SQA serialization."""
344+
# Create experiment and set concurrency_limit
345+
experiment = get_experiment_with_batch_trial()
346+
experiment.design.concurrency_limit = 42
347+
348+
# Save and load experiment through SQA
349+
save_experiment(experiment)
350+
loaded_experiment = load_experiment(experiment.name)
351+
352+
# Verify ExperimentDesign is preserved
353+
self.assertEqual(loaded_experiment, experiment)
354+
self.assertEqual(loaded_experiment.design.concurrency_limit, 42)
355+
356+
def test_experiment_design_none_concurrency_sqa_roundtrip(self) -> None:
357+
"""Test that ExperimentDesign with None concurrency_limit is preserved."""
358+
# Create experiment with default (None) concurrency_limit
359+
experiment = get_experiment_with_batch_trial()
360+
experiment.name = "experiment_design_none_concurrency_test"
361+
self.assertIsNone(experiment.design.concurrency_limit)
362+
363+
# Save and load experiment through SQA
364+
save_experiment(experiment)
365+
loaded_experiment = load_experiment(experiment.name)
366+
367+
# Verify ExperimentDesign with None is preserved
368+
self.assertEqual(loaded_experiment, experiment)
369+
self.assertIsNone(loaded_experiment.design.concurrency_limit)
370+
342371
def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
343372
aux_experiment = Experiment(
344373
name="test_aux_exp_in_SQAStoreTest",

0 commit comments

Comments
 (0)