Skip to content

Commit 4aa249c

Browse files
shrutipatel31 authored and facebook-github-bot committed
Extract early stopping replay utilities to OSS (#4744)
Summary: Adds the `estimate_hypothetical_early_stopping_savings()` function to the OSS module. This function estimates potential compute savings by replaying an experiment with a default early stopping strategy.

Reviewed By: bernardbeckerman

Differential Revision: D90150341
1 parent c8a5280 commit 4aa249c

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

ax/early_stopping/experiment_replay.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from ax.core.optimization_config import OptimizationConfig
1818
from ax.core.parameter import ParameterType, RangeParameter
1919
from ax.core.search_space import SearchSpace
20+
from ax.early_stopping.dispatch import get_default_ess_or_none
2021
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
22+
from ax.early_stopping.utils import estimate_early_stopping_savings
23+
from ax.exceptions.core import UnsupportedError
2124
from ax.generation_strategy.generation_strategy import (
2225
GenerationStep,
2326
GenerationStrategy,
@@ -29,6 +32,12 @@
2932

3033
logger: Logger = get_logger(__name__)
3134

35+
# Constants for experiment replay. The first three are the fixed arguments
# that estimate_hypothetical_early_stopping_savings() hands to
# replay_experiment().
36+
MAX_REPLAY_TRIALS: int = 50  # passed as `max_replay_trials`
37+
REPLAY_NUM_POINTS_PER_CURVE: int = 20  # passed as `num_samples_per_curve`
38+
MAX_PENDING_TRIALS: int = 5  # default value of the `max_pending_trials` argument
39+
MIN_SAVINGS_THRESHOLD: float = 0.1 # 10% threshold; not referenced in the code visible here
40+
3241

3342
def replay_experiment(
3443
historical_experiment: Experiment,
@@ -105,3 +114,55 @@ def replay_experiment(
105114
orchestrator.run_all_trials()
106115
logger.info(f"Replayed the experiment in {perf_counter() - start_time} seconds.")
107116
return experiment
117+
118+
119+
def estimate_hypothetical_early_stopping_savings(
    experiment: Experiment,
    metric: Metric,
    max_pending_trials: int = MAX_PENDING_TRIALS,
) -> float:
    """Estimate hypothetical early stopping savings using experiment replay.

    Replays ``experiment`` with the default early stopping strategy returned
    by ``get_default_ess_or_none`` and computes the savings on the replayed
    experiment, i.e. what would have been saved had early stopping been
    enabled.

    Args:
        experiment: The experiment to analyze.
        metric: The metric to use for early stopping replay.
        max_pending_trials: Maximum number of pending trials for the replay
            orchestrator. Defaults to ``MAX_PENDING_TRIALS``.

    Returns:
        Estimated savings as a fraction (0.0 to 1.0).

    Raises:
        UnsupportedError: If early stopping savings cannot be estimated.
            This function raises it in exactly two situations:
            - No default early stopping strategy is available for this
              experiment (e.g., multi-objective, constrained, or
              non-MapMetric experiments).
            - ``replay_experiment`` returns ``None``, which is reported as
              the experiment data lacking progression data for replay.

    Note:
        Exceptions raised while replaying (for example a ``ValueError``
        caused by invalid experiment state) are neither caught nor wrapped
        here; they propagate to the caller unchanged.
    """
    default_ess = get_default_ess_or_none(experiment=experiment)
    if default_ess is None:
        raise UnsupportedError(
            "No default early stopping strategy available (multi-objective, "
            "constrained, or non-MapMetric experiment)."
        )

    # Replay with fixed sampling/trial limits; only the pending-trial cap is
    # caller-configurable.
    replayed_experiment = replay_experiment(
        historical_experiment=experiment,
        num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
        max_replay_trials=MAX_REPLAY_TRIALS,
        metric=metric,
        max_pending_trials=max_pending_trials,
        early_stopping_strategy=default_ess,
    )

    # A `None` result means no replayed experiment was produced.
    if replayed_experiment is None:
        raise UnsupportedError(
            "Experiment data does not have progression data for replay."
        )

    return estimate_early_stopping_savings(experiment=replayed_experiment)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from unittest.mock import patch
10+
11+
from ax.early_stopping.experiment_replay import (
12+
estimate_hypothetical_early_stopping_savings,
13+
)
14+
from ax.exceptions.core import UnsupportedError
15+
from ax.utils.common.testutils import TestCase
16+
from ax.utils.testing.core_stubs import (
17+
get_branin_experiment,
18+
get_branin_experiment_with_timestamp_map_metric,
19+
)
20+
from pyre_extensions import none_throws
21+
22+
23+
class TestEstimateHypotheticalEss(TestCase):
    """Tests for ``estimate_hypothetical_early_stopping_savings``."""

    # Patch targets: the names as looked up inside the module under test.
    _REPLAY_TARGET = "ax.early_stopping.experiment_replay.replay_experiment"
    _SAVINGS_TARGET = (
        "ax.early_stopping.experiment_replay.estimate_early_stopping_savings"
    )

    def setUp(self) -> None:
        super().setUp()
        # MapMetric-backed experiment, used by the tests that need a valid
        # default early stopping strategy.
        self.exp = get_branin_experiment_with_timestamp_map_metric()
        opt_config = none_throws(self.exp.optimization_config)
        self.metric = opt_config.objective.metric

    def test_estimate_hypothetical_ess_no_default_strategy(self) -> None:
        """UnsupportedError is raised when there is no default ESS."""
        # A non-MapMetric experiment has no default ESS.
        plain_exp = get_branin_experiment(has_optimization_config=True)
        plain_metric = none_throws(
            plain_exp.optimization_config
        ).objective.metric

        with self.assertRaises(UnsupportedError) as ctx:
            estimate_hypothetical_early_stopping_savings(
                experiment=plain_exp,
                metric=plain_metric,
            )

        message = str(ctx.exception)
        self.assertIn("No default early stopping strategy available", message)

    def test_estimate_hypothetical_ess_no_progression_data(self) -> None:
        """UnsupportedError is raised when the replay yields no experiment."""
        with (
            patch(self._REPLAY_TARGET, return_value=None),
            self.assertRaises(UnsupportedError) as ctx,
        ):
            estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        message = str(ctx.exception)
        self.assertIn(
            "Experiment data does not have progression data for replay",
            message,
        )

    def test_estimate_hypothetical_ess_success(self) -> None:
        """The savings estimate is returned when the replay succeeds."""
        with (
            patch(self._REPLAY_TARGET) as replay_mock,
            patch(self._SAVINGS_TARGET, return_value=0.25) as savings_mock,
        ):
            savings = estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        self.assertEqual(savings, 0.25)
        replay_mock.assert_called_once()
        savings_mock.assert_called_once()

    def test_estimate_hypothetical_ess_exception(self) -> None:
        """Exceptions raised by the replay reach the caller unchanged."""
        replay_error = ValueError("Experiment's name is None.")
        with (
            patch(self._REPLAY_TARGET, side_effect=replay_error),
            self.assertRaises(ValueError) as ctx,
        ):
            estimate_hypothetical_early_stopping_savings(
                experiment=self.exp,
                metric=self.metric,
            )

        self.assertIn("Experiment's name is None.", str(ctx.exception))

0 commit comments

Comments
 (0)