Skip to content

Commit 530e212

Browse files
shrutipatel31 authored and
facebook-github-bot committed
Extract early stopping replay utilities to OSS (#4744)
Summary: Adds the `estimate_hypothetical_early_stopping_savings()` function to the OSS module. This function estimates potential compute savings by replaying an experiment with a default early stopping strategy. Differential Revision: D90150341
1 parent 2b16f4e commit 530e212

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

ax/early_stopping/experiment_replay.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from ax.core.optimization_config import OptimizationConfig
1818
from ax.core.parameter import ParameterType, RangeParameter
1919
from ax.core.search_space import SearchSpace
20+
from ax.early_stopping.dispatch import get_default_ess_or_none
2021
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
22+
from ax.early_stopping.utils import estimate_early_stopping_savings
2123
from ax.generation_strategy.generation_strategy import (
2224
GenerationStep,
2325
GenerationStrategy,
@@ -29,6 +31,11 @@
2931

3032
logger: Logger = get_logger(__name__)
3133

34+
# Constants for experiment replay
35+
# Passed to replay_experiment as `max_replay_trials` by
# estimate_hypothetical_early_stopping_savings.
MAX_REPLAY_TRIALS: int = 50
36+
# Passed to replay_experiment as `num_samples_per_curve` by
# estimate_hypothetical_early_stopping_savings.
REPLAY_NUM_POINTS_PER_CURVE: int = 20
37+
# Default value of the `max_pending_trials` argument of
# estimate_hypothetical_early_stopping_savings.
MAX_PENDING_TRIALS: int = 5
38+
3239

3340
def replay_experiment(
3441
historical_experiment: Experiment,
@@ -105,3 +112,56 @@ def replay_experiment(
105112
orchestrator.run_all_trials()
106113
logger.info(f"Replayed the experiment in {perf_counter() - start_time} seconds.")
107114
return experiment
115+
116+
117+
def estimate_hypothetical_early_stopping_savings(
    experiment: Experiment,
    metric: Metric,
    max_pending_trials: int = MAX_PENDING_TRIALS,
) -> float | None:
    """Estimate what early stopping would have saved on a finished experiment.

    The experiment is replayed under the default early stopping strategy
    selected for it, and the savings are then computed on the replayed
    experiment. This is a best-effort estimate: every failure mode is logged
    at info level and reported to the caller as ``None`` instead of raising.

    Args:
        experiment: The historical experiment to replay.
        metric: The metric forwarded to the replay.
        max_pending_trials: Maximum number of pending trials forwarded to the
            replay. Defaults to ``MAX_PENDING_TRIALS`` (5).

    Returns:
        The estimated savings as a fraction (0.0 to 1.0), or ``None`` when:
        - no default early stopping strategy is available for the experiment
          (e.g., multi-objective, constrained, or non-MapMetric experiments),
        - the replay returned ``None`` (no progression data to replay), or
        - any exception was raised along the way.
    """
    # Single exit point: stays None unless every step below succeeds.
    savings: float | None = None
    try:
        strategy = get_default_ess_or_none(experiment=experiment)
        if strategy is None:
            logger.info(
                "No default early stopping strategy available (multi-objective, "
                "constrained, or non-MapMetric experiment)."
            )
        else:
            replayed = replay_experiment(
                historical_experiment=experiment,
                num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
                max_replay_trials=MAX_REPLAY_TRIALS,
                metric=metric,
                max_pending_trials=max_pending_trials,
                early_stopping_strategy=strategy,
            )
            if replayed is None:
                logger.info(
                    "Experiment data does not have progression data for replay."
                )
            else:
                savings = estimate_early_stopping_savings(experiment=replayed)
    except Exception as e:
        # Replay can fail due to invalid experiment state (e.g., missing name,
        # incompatible data format) or internal errors during orchestration.
        # Deliberately swallowed: callers only see None.
        logger.info(f"Experiment replay failed with exception: {e}")
    return savings
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from unittest.mock import patch
10+
11+
from ax.early_stopping.experiment_replay import (
12+
estimate_hypothetical_early_stopping_savings,
13+
logger,
14+
)
15+
from ax.utils.common.testutils import TestCase
16+
from ax.utils.testing.core_stubs import (
17+
get_branin_experiment,
18+
get_branin_experiment_with_timestamp_map_metric,
19+
)
20+
from pyre_extensions import none_throws
21+
22+
23+
class TestEstimateHypotheticalEss(TestCase):
    """Unit tests for ``estimate_hypothetical_early_stopping_savings``.

    ``replay_experiment`` and ``estimate_early_stopping_savings`` are patched
    in most tests so that only the control flow of the function under test
    (strategy lookup, replay, savings estimation, error handling) is exercised.
    """

    def setUp(self) -> None:
        super().setUp()
        # Experiment with MapMetric for tests that need a valid default ESS.
        self.exp = get_branin_experiment_with_timestamp_map_metric()
        self.metric = none_throws(self.exp.optimization_config).objective.metric

    def test_estimate_hypothetical_ess_no_default_strategy(self) -> None:
        """Test that None is returned when no default ESS is available."""
        # Non-MapMetric experiment has no default ESS.
        exp = get_branin_experiment(has_optimization_config=True)
        metric = none_throws(exp.optimization_config).objective.metric

        with patch.object(logger, "info") as mock_info:
            result = estimate_hypothetical_early_stopping_savings(
                experiment=exp,
                metric=metric,
            )

        self.assertIsNone(result)
        mock_info.assert_called_once_with(
            "No default early stopping strategy available (multi-objective, "
            "constrained, or non-MapMetric experiment)."
        )

    def test_estimate_hypothetical_ess_no_progression_data(self) -> None:
        """Test that None is returned when experiment has no progression data."""
        with (
            patch(
                "ax.early_stopping.experiment_replay.replay_experiment",
                return_value=None,
            ),
            patch.object(logger, "info") as mock_info,
        ):
            result = estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        self.assertIsNone(result)
        mock_info.assert_called_once_with(
            "Experiment data does not have progression data for replay."
        )

    def test_estimate_hypothetical_ess_success(self) -> None:
        """Test that savings are returned when replay succeeds."""
        with (
            patch(
                "ax.early_stopping.experiment_replay.replay_experiment",
            ) as mock_replay,
            patch(
                "ax.early_stopping.experiment_replay.estimate_early_stopping_savings",
                return_value=0.25,
            ) as mock_estimate,
        ):
            result = estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        self.assertEqual(result, 0.25)
        # The replay must be driven by the caller's experiment and metric ...
        mock_replay.assert_called_once()
        replay_kwargs = mock_replay.call_args.kwargs
        self.assertIs(replay_kwargs["historical_experiment"], self.exp)
        self.assertIs(replay_kwargs["metric"], self.metric)
        # ... and savings must be computed on the *replayed* experiment, not
        # on the historical one.
        mock_estimate.assert_called_once_with(experiment=mock_replay.return_value)

    def test_estimate_hypothetical_ess_exception(self) -> None:
        """Test that None is returned when replay raises an exception."""
        with (
            patch(
                "ax.early_stopping.experiment_replay.replay_experiment",
                side_effect=ValueError("Experiment's name is None."),
            ),
            patch.object(logger, "info") as mock_info,
        ):
            result = estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        self.assertIsNone(result)
        mock_info.assert_called_once_with(
            "Experiment replay failed with exception: Experiment's name is None."
        )

0 commit comments

Comments
 (0)