Skip to content

Commit 0f5214a

Browse files
shrutipatel31 authored and facebook-github-bot committed
Update ESS nudge logic in EarlyStopping Healthcheck (#4745)
Summary: Pull Request resolved: #4745. Differential Revision: D90035967. Privacy Context Container: L1307644.
1 parent 596be05 commit 0f5214a

File tree

2 files changed

+69
-109
lines changed

2 files changed

+69
-109
lines changed

ax/analysis/healthcheck/early_stopping_healthcheck.py

Lines changed: 53 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@
2222
from ax.core.map_metric import MapMetric
2323
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
2424
from ax.early_stopping.dispatch import get_default_ess_or_none
25-
from ax.early_stopping.experiment_replay import replay_experiment
25+
from ax.early_stopping.experiment_replay import (
26+
estimate_hypothetical_early_stopping_savings,
27+
)
2628
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
27-
from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy
2829
from ax.early_stopping.utils import estimate_early_stopping_savings
2930
from ax.generation_strategy.generation_strategy import GenerationStrategy
3031
from ax.service.utils.early_stopping import get_early_stopping_metrics
3132
from pyre_extensions import none_throws, override
3233

33-
DEFAULT_MIN_SAVINGS_THRESHOLD = 0.01 # 1% threshold
34+
DEFAULT_MIN_SAVINGS_THRESHOLD = 0.1 # 10% threshold
3435
MAX_PENDING_TRIALS_DEFAULT = 5
3536
DEFAULT_EARLY_STOPPING_HEALTHCHECK_TITLE = "Early Stopping Healthcheck"
3637

@@ -87,7 +88,7 @@ def __init__(
8788
default early stopping strategy will only be used for
8889
single-objective unconstrained experiments.
8990
min_savings_threshold: Minimum savings threshold to suggest early
90-
stopping. Default is 0.01 (1% savings).
91+
stopping. Default is 0.1 (10% savings).
9192
max_pending_trials: Maximum number of pending trials for replay
9293
orchestrator. Default is 5.
9394
auto_early_stopping_config: A string for configuring automated early
@@ -407,22 +408,41 @@ def _report_early_stopping_nudge(
407408
self, experiment: Experiment
408409
) -> HealthcheckAnalysisCard:
409410
"""Check if early stopping should be suggested (nudge) by estimating
410-
hypothetical savings using replay logic."""
411-
# Get map metrics from the experiment
412-
# Note: validate_applicable_state already ensures map_metrics is non-empty
413-
map_metrics = self._get_map_metrics(experiment)
414-
415-
# Estimate hypothetical savings for compatible metrics using replay
416-
metric_to_savings = self._estimate_hypothetical_savings_with_replay(
417-
experiment=experiment, map_metrics=map_metrics
411+
hypothetical savings using replay logic.
412+
413+
Only applicable for single-objective unconstrained experiments where a
414+
default early stopping strategy is available.
415+
"""
416+
default_ess = get_default_ess_or_none(experiment=experiment)
417+
if default_ess is None:
418+
return self._create_card(
419+
subtitle=(
420+
"This experiment is multi-objective or has constraints, so "
421+
"automatic early stopping savings estimation is not available. "
422+
"If you want to use early stopping for this experiment type, "
423+
"please configure an early_stopping_strategy explicitly."
424+
),
425+
status=HealthcheckStatus.PASS,
426+
)
427+
428+
# get_default_ess_or_none returns a default ESS only for single-objective
429+
# unconstrained experiments with a valid optimization config, so we can
430+
# safely assume opt_config is not None here.
431+
opt_config = none_throws(experiment.optimization_config)
432+
433+
# Estimate hypothetical savings using replay with the default ESS
434+
metric = next(iter(opt_config.objective.metrics))
435+
savings = estimate_hypothetical_early_stopping_savings(
436+
experiment=experiment,
437+
metric=metric,
438+
max_pending_trials=self.max_pending_trials,
439+
minimize=opt_config.objective.minimize,
418440
)
419441

420-
if not metric_to_savings:
421-
# No significant savings detected
442+
if savings is None or savings < self.min_savings_threshold:
422443
return self._create_card(
423444
subtitle=(
424-
"Early stopping is not enabled. While this experiment has "
425-
"data with a progression ('step' column) we did not detect "
445+
"Early stopping is not enabled. We did not detect "
426446
"significant potential savings at this time.\n\n"
427447
"This could be because:\n"
428448
"- The experiment hasn't run enough trials yet\n"
@@ -434,15 +454,14 @@ def _report_early_stopping_nudge(
434454
)
435455

436456
# Found significant potential savings - nudge the user
437-
best_metric_name = max(metric_to_savings, key=metric_to_savings.get)
438-
best_savings = metric_to_savings[best_metric_name]
457+
savings_pct = 100 * savings
439458

440459
subtitle = (
441-
"This sweep uses metrics that are **compatible with early stopping**! "
442-
"Using early stopping could have saved you both capacity and "
443-
"optimization wall time. For example, we estimate that using early "
444-
f"stopping on the '{best_metric_name}' metric could have provided "
445-
f"{best_savings:.0f}% capacity savings, with no regression in "
460+
"This experiment uses metrics that are **compatible with early "
461+
"stopping**! Using early stopping could have saved you both capacity "
462+
"and optimization wall time. For example, we estimate that using early "
463+
f"stopping on the '{metric.name}' metric could have provided "
464+
f"{savings_pct:.0f}% capacity savings, with no regression in "
446465
"optimization performance."
447466
)
448467

@@ -451,19 +470,17 @@ def _report_early_stopping_nudge(
451470
subtitle += f" {self.nudge_additional_info}"
452471

453472
# Create detailed metrics table
454-
metric_rows = [
455-
{
456-
"Metric Name": metric_name,
457-
"Estimated Savings": f"{savings:.1f}%",
458-
}
459-
for metric_name, savings in sorted(
460-
metric_to_savings.items(), key=lambda x: x[1], reverse=True
461-
)
462-
]
463-
df = pd.DataFrame(metric_rows)
473+
df = pd.DataFrame(
474+
[
475+
{
476+
"Metric Name": metric.name,
477+
"Estimated Savings": f"{savings_pct:.1f}%",
478+
}
479+
]
480+
)
464481

465482
title = (
466-
f"{best_savings:.0f}% potential capacity savings if you turn on "
483+
f"{savings_pct:.0f}% potential capacity savings if you turn on "
467484
f"early stopping feature"
468485
)
469486

@@ -472,8 +489,8 @@ def _report_early_stopping_nudge(
472489
subtitle=subtitle,
473490
df=df,
474491
status=HealthcheckStatus.WARNING,
475-
potential_savings=best_savings,
476-
best_metric=best_metric_name,
492+
potential_savings=savings_pct,
493+
best_metric=metric.name,
477494
)
478495

479496
def _get_problem_type(self, experiment: Experiment) -> str:
@@ -504,63 +521,3 @@ def _get_map_metrics(self, experiment: Experiment) -> list[MapMetric]:
504521
reverse=True,
505522
)
506523
return map_metrics
507-
508-
def _estimate_hypothetical_savings_with_replay(
509-
self, experiment: Experiment, map_metrics: list[MapMetric]
510-
) -> dict[str, float]:
511-
"""
512-
Estimate hypothetical early stopping savings for each map metric using
513-
replay infrastructure.
514-
515-
This is the accurate method that replays the experiment with early stopping
516-
enabled to calculate actual savings.
517-
518-
Args:
519-
experiment: The experiment to analyze
520-
map_metrics: List of MapMetrics to analyze
521-
522-
Returns:
523-
Dictionary mapping metric names to estimated savings percentages
524-
(only includes metrics where savings > min_savings_threshold)
525-
"""
526-
metric_to_savings: dict[str, float] = {}
527-
528-
MAX_REPLAYS = 3
529-
MAX_REPLAY_TRIALS = 50
530-
REPLAY_NUM_POINTS_PER_CURVE = 20
531-
REPLAY_PERCENTILE_THRESHOLD = 65
532-
REPLAY_MIN_PROGRESSION_FRAC = 0.4
533-
REPLAY_MIN_CURVES = 5
534-
535-
# Limit to first few metrics to avoid expensive computation
536-
for map_metric in map_metrics[:MAX_REPLAYS]:
537-
try:
538-
# Create replayed experiment with early stopping
539-
replayed_experiment = replay_experiment(
540-
historical_experiment=experiment,
541-
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
542-
max_replay_trials=MAX_REPLAY_TRIALS,
543-
metric=map_metric,
544-
max_pending_trials=self.max_pending_trials,
545-
early_stopping_strategy=PercentileEarlyStoppingStrategy(
546-
min_curves=REPLAY_MIN_CURVES,
547-
min_progression=REPLAY_MIN_PROGRESSION_FRAC,
548-
percentile_threshold=REPLAY_PERCENTILE_THRESHOLD,
549-
normalize_progressions=True,
550-
),
551-
)
552-
553-
if replayed_experiment is not None:
554-
savings = estimate_early_stopping_savings(
555-
experiment=replayed_experiment
556-
)
557-
558-
# Only include if savings exceed threshold (> 1%)
559-
if savings > self.min_savings_threshold:
560-
metric_to_savings[map_metric.name] = 100 * savings
561-
562-
except Exception:
563-
# Skip metrics that fail replay
564-
continue
565-
566-
return metric_to_savings

ax/analysis/healthcheck/tests/test_early_stopping_healthcheck.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,20 @@ def test_early_stopping_not_enabled(self) -> None:
140140
healthcheck = EarlyStoppingAnalysis(early_stopping_strategy=None)
141141

142142
with self.subTest("no_savings_detected"):
143-
card = healthcheck.compute(experiment=self.experiment)
143+
# Mock estimate_hypothetical_early_stopping_savings to return None
144+
with patch(
145+
"ax.analysis.healthcheck.early_stopping_healthcheck"
146+
".estimate_hypothetical_early_stopping_savings",
147+
return_value=None,
148+
):
149+
card = healthcheck.compute(experiment=self.experiment)
144150
self.assertIn("Early stopping is not enabled", card.subtitle)
145151

146152
with self.subTest("potential_savings_detected"):
147-
mock_savings = {"ax_test_metric": 25.0}
148-
with patch.object(
149-
healthcheck,
150-
"_estimate_hypothetical_savings_with_replay",
153+
mock_savings = 0.25 # 25% as a decimal
154+
with patch(
155+
"ax.analysis.healthcheck.early_stopping_healthcheck"
156+
".estimate_hypothetical_early_stopping_savings",
151157
return_value=mock_savings,
152158
):
153159
card = healthcheck.compute(experiment=self.experiment)
@@ -324,11 +330,8 @@ def test_hypothetical_savings_nudge(self) -> None:
324330

325331
with self.subTest("basic_nudge"):
326332
with patch(
327-
"ax.analysis.healthcheck.early_stopping_healthcheck.replay_experiment",
328-
return_value=object(),
329-
), patch(
330333
"ax.analysis.healthcheck.early_stopping_healthcheck"
331-
".estimate_early_stopping_savings",
334+
".estimate_hypothetical_early_stopping_savings",
332335
return_value=0.25,
333336
):
334337
card = healthcheck.compute(experiment=self.experiment)
@@ -345,10 +348,10 @@ def test_hypothetical_savings_nudge(self) -> None:
345348
early_stopping_strategy=None, nudge_additional_info=nudge_info
346349
)
347350

348-
mock_savings = {"ax_test_metric": 25.0}
349-
with patch.object(
350-
healthcheck_with_info,
351-
"_estimate_hypothetical_savings_with_replay",
351+
mock_savings = 0.25
352+
with patch(
353+
"ax.analysis.healthcheck.early_stopping_healthcheck"
354+
".estimate_hypothetical_early_stopping_savings",
352355
return_value=mock_savings,
353356
):
354357
card = healthcheck_with_info.compute(experiment=self.experiment)

0 commit comments

Comments
 (0)