Skip to content

Commit 177cfd2

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Extract early stopping replay utilities to OSS (#4744)
Summary: Adds the `estimate_hypothetical_early_stopping_savings()` function to the OSS module. This function estimates potential compute savings by replaying an experiment with a default early stopping strategy. Key changes: - Added `estimate_hypothetical_early_stopping_savings()` to `experiment_replay.py` which combines `get_default_ess_or_none()`, `replay_experiment()`, and `estimate_early_stopping_savings()` into a single utility - Added constants `MAX_REPLAY_TRIALS`, `REPLAY_NUM_POINTS_PER_CURVE`, and `MAX_PENDING_TRIALS` to `experiment_replay.py` - Added optional `minimize` parameter to `replay_experiment()` to explicitly control optimization direction - Updated `ax_sweep_orchestrator.py` to use the new `estimate_hypothetical_early_stopping_savings()` function - Added unit tests for the new function in `test_experiment_replay.py` Differential Revision: D90150341
1 parent 5dccf83 commit 177cfd2

File tree

3 files changed

+238
-0
lines changed

3 files changed

+238
-0
lines changed

ax/early_stopping/experiment_replay.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from ax.core.optimization_config import OptimizationConfig
1818
from ax.core.parameter import ParameterType, RangeParameter
1919
from ax.core.search_space import SearchSpace
20+
from ax.early_stopping.dispatch import get_default_ess_or_none
2021
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
22+
from ax.early_stopping.utils import estimate_early_stopping_savings
2123
from ax.generation_strategy.generation_strategy import (
2224
GenerationStep,
2325
GenerationStrategy,
@@ -29,6 +31,11 @@
2931

3032
logger: Logger = get_logger(__name__)
3133

34+
# Constants for experiment replay
35+
MAX_REPLAY_TRIALS: int = 50
36+
REPLAY_NUM_POINTS_PER_CURVE: int = 20
37+
MAX_PENDING_TRIALS: int = 5
38+
3239

3340
def replay_experiment(
3441
historical_experiment: Experiment,
@@ -105,3 +112,54 @@ def replay_experiment(
105112
orchestrator.run_all_trials()
106113
logger.info(f"Replayed the experiment in {perf_counter() - start_time} seconds.")
107114
return experiment
115+
116+
117+
def estimate_hypothetical_early_stopping_savings(
118+
experiment: Experiment,
119+
metric: Metric,
120+
max_pending_trials: int = MAX_PENDING_TRIALS,
121+
) -> float | None:
122+
"""Estimate hypothetical early stopping savings using experiment replay.
123+
124+
This function replays the experiment with a default early stopping strategy
125+
to calculate what savings would have been achieved if early stopping were
126+
enabled.
127+
128+
Note: Returns None for multi-objective, constrained, or non-MapMetric
129+
experiments, as `get_default_ess_or_none` does not provide a default
130+
early stopping strategy for these experiment types.
131+
132+
Args:
133+
experiment: The experiment to analyze.
134+
metric: The metric to use for early stopping replay.
135+
max_pending_trials: Maximum number of pending trials for the replay
136+
orchestrator. Defaults to 5.
137+
138+
Returns:
139+
Estimated savings as a fraction (0.0 to 1.0), or None if:
140+
- No default early stopping strategy is available for this experiment
141+
(e.g., multi-objective, constrained, or non-MapMetric experiments)
142+
- The experiment replay failed
143+
"""
144+
try:
145+
default_ess = get_default_ess_or_none(experiment=experiment)
146+
if default_ess is None:
147+
return None
148+
149+
replayed_experiment = replay_experiment(
150+
historical_experiment=experiment,
151+
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
152+
max_replay_trials=MAX_REPLAY_TRIALS,
153+
metric=metric,
154+
max_pending_trials=max_pending_trials,
155+
early_stopping_strategy=default_ess,
156+
)
157+
158+
if replayed_experiment is None:
159+
return None
160+
161+
return estimate_early_stopping_savings(experiment=replayed_experiment)
162+
except Exception:
163+
# Replay can fail due to invalid experiment state (e.g., missing name,
164+
# incompatible data format) or internal errors during orchestration.
165+
return None
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from unittest.mock import MagicMock, patch
10+
11+
from ax.early_stopping.experiment_replay import (
12+
estimate_hypothetical_early_stopping_savings,
13+
)
14+
from ax.utils.common.testutils import TestCase
15+
from ax.utils.testing.core_stubs import (
16+
get_branin_experiment,
17+
get_branin_experiment_with_timestamp_map_metric,
18+
)
19+
from pyre_extensions import none_throws
20+
21+
22+
class TestEstimateHypotheticalEarlyStoppingSavings(TestCase):
23+
def test_returns_none_for_non_map_metric_experiment(self) -> None:
24+
"""Test that None is returned when experiment has no MapMetric."""
25+
exp = get_branin_experiment(has_optimization_config=True)
26+
metric = none_throws(exp.optimization_config).objective.metric
27+
28+
result = estimate_hypothetical_early_stopping_savings(
29+
experiment=exp,
30+
metric=metric,
31+
)
32+
33+
self.assertIsNone(result)
34+
35+
def test_returns_none_for_multi_objective(self) -> None:
36+
"""Test that None is returned for multi-objective experiments."""
37+
exp = get_branin_experiment_with_timestamp_map_metric(multi_objective=True)
38+
# Use first metric from optimization config for multi-objective
39+
metric = list(none_throws(exp.optimization_config).metrics.values())[0]
40+
41+
result = estimate_hypothetical_early_stopping_savings(
42+
experiment=exp,
43+
metric=metric,
44+
)
45+
46+
self.assertIsNone(result)
47+
48+
def test_returns_none_for_constrained_experiment(self) -> None:
49+
"""Test that None is returned for experiments with outcome constraints."""
50+
exp = get_branin_experiment_with_timestamp_map_metric(
51+
with_outcome_constraint=True
52+
)
53+
metric = none_throws(exp.optimization_config).objective.metric
54+
55+
result = estimate_hypothetical_early_stopping_savings(
56+
experiment=exp,
57+
metric=metric,
58+
)
59+
60+
self.assertIsNone(result)
61+
62+
@patch("ax.early_stopping.experiment_replay.replay_experiment")
63+
def test_returns_none_when_replay_fails(
64+
self, mock_replay_experiment: MagicMock
65+
) -> None:
66+
"""Test that None is returned when replay_experiment fails."""
67+
exp = get_branin_experiment_with_timestamp_map_metric()
68+
metric = none_throws(exp.optimization_config).objective.metric
69+
mock_replay_experiment.return_value = None
70+
71+
result = estimate_hypothetical_early_stopping_savings(
72+
experiment=exp,
73+
metric=metric,
74+
)
75+
76+
self.assertIsNone(result)
77+
mock_replay_experiment.assert_called_once()
78+
79+
@patch("ax.early_stopping.experiment_replay.estimate_early_stopping_savings")
80+
@patch("ax.early_stopping.experiment_replay.replay_experiment")
81+
def test_returns_savings_on_successful_replay(
82+
self,
83+
mock_replay_experiment: MagicMock,
84+
mock_estimate_savings: MagicMock,
85+
) -> None:
86+
"""Test that savings are returned when replay succeeds."""
87+
exp = get_branin_experiment_with_timestamp_map_metric()
88+
metric = none_throws(exp.optimization_config).objective.metric
89+
mock_replayed_exp = MagicMock()
90+
mock_replay_experiment.return_value = mock_replayed_exp
91+
mock_estimate_savings.return_value = 0.25
92+
93+
result = estimate_hypothetical_early_stopping_savings(
94+
experiment=exp,
95+
metric=metric,
96+
)
97+
98+
self.assertEqual(result, 0.25)
99+
mock_estimate_savings.assert_called_once_with(experiment=mock_replayed_exp)
100+
101+
@patch("ax.early_stopping.experiment_replay.replay_experiment")
102+
def test_returns_none_when_exception_raised(
103+
self, mock_replay_experiment: MagicMock
104+
) -> None:
105+
"""Test that None is returned when replay fails due to invalid experiment
106+
state (e.g., missing name) or internal orchestration errors.
107+
"""
108+
exp = get_branin_experiment_with_timestamp_map_metric()
109+
metric = none_throws(exp.optimization_config).objective.metric
110+
mock_replay_experiment.side_effect = ValueError("Experiment's name is None.")
111+
112+
result = estimate_hypothetical_early_stopping_savings(
113+
experiment=exp,
114+
metric=metric,
115+
)
116+
117+
self.assertIsNone(result)
118+
mock_replay_experiment.assert_called_once()

ax/early_stopping/utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,68 @@
1919

2020
logger: Logger = get_logger(__name__)
2121

22+
# Early stopping message constants for use in analysis and reporting
23+
EARLY_STOPPING_STATUS_MSG = (
24+
"Throughout this experiment, {n_stopped} trials were early stopped, out "
25+
"of a total of {n_ran} trials. "
26+
)
27+
28+
EARLY_STOPPING_SAVINGS_TITLE = "Capacity savings due to early stopping"
29+
30+
EARLY_STOPPING_SAVINGS_MSG = (
31+
"The capacity savings (computed using {map_key}) are estimated to be "
32+
"{savings:.0f}%."
33+
)
34+
35+
EARLY_STOPPING_SAVINGS_TBD = (
36+
"Capacity savings are not yet available. Either no trials have been early "
37+
"stopped, or no trials have completed (which is required to estimate "
38+
"savings). Check back once more trials are completed and/or early stopped."
39+
)
40+
41+
EARLY_STOPPING_NUDGE_MSG = (
42+
"This sweep uses metrics that are **compatible with early stopping**! "
43+
"Using early stopping could have saved you both capacity and optimization "
44+
"wall time. For example, we estimate that using early stopping on the "
45+
"'{metric_name}' metric could have provided {savings:.0f}% capacity "
46+
"savings, with no regression in optimization performance."
47+
)
48+
49+
EARLY_STOPPING_NUDGE_TITLE = (
50+
"{savings:.0f}% potential capacity savings if you turn on " "early stopping feature"
51+
)
52+
53+
54+
def format_early_stopping_savings_message(
55+
n_stopped: int,
56+
n_ran: int,
57+
savings: float,
58+
) -> str:
59+
"""Format a message describing early stopping status and savings.
60+
61+
This function consolidates the common logic used by both AxSweep and the
62+
early stopping healthcheck to format early stopping status messages.
63+
64+
Args:
65+
n_stopped: Number of trials that were early stopped.
66+
n_ran: Total number of trials that ran (stopped + completed + failed + running).
67+
savings: Resource savings as a fraction (0.0 to 1.0). For example, 0.11
68+
indicates 11% savings.
69+
70+
Returns:
71+
A formatted message string describing the early stopping status and
72+
either the estimated savings percentage or a note that savings are
73+
not yet available.
74+
"""
75+
msg = EARLY_STOPPING_STATUS_MSG.format(n_stopped=n_stopped, n_ran=n_ran)
76+
77+
if savings > 0:
78+
msg += EARLY_STOPPING_SAVINGS_MSG.format(map_key=MAP_KEY, savings=savings * 100)
79+
else:
80+
msg += EARLY_STOPPING_SAVINGS_TBD
81+
82+
return msg
83+
2284

2385
def _is_worse(a: Any, b: Any, minimize: bool) -> Any:
2486
"""Determine if value `a` is worse than value `b` based on optimization direction.

0 commit comments

Comments
 (0)