Skip to content

Commit 2223d45

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Update ESS nudge logic in EarlyStopping Healthcheck (#4745)
Summary: Pull Request resolved: #4745 Differential Revision: D90035967 Privacy Context Container: L1307644
1 parent 3294d10 commit 2223d45

File tree

2 files changed

+82
-106
lines changed

2 files changed

+82
-106
lines changed

ax/analysis/healthcheck/early_stopping_healthcheck.py

Lines changed: 48 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from ax.core.map_metric import MapMetric
2222
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
2323
from ax.early_stopping.dispatch import get_default_ess_or_none
24-
from ax.early_stopping.experiment_replay import replay_experiment
24+
from ax.early_stopping.experiment_replay import (
25+
estimate_hypothetical_early_stopping_savings,
26+
)
2527
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
26-
from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy
2728
from ax.early_stopping.utils import (
2829
EARLY_STOPPING_NUDGE_MSG,
2930
EARLY_STOPPING_NUDGE_TITLE,
@@ -35,7 +36,7 @@
3536
from ax.service.utils.early_stopping import get_early_stopping_metrics
3637
from pyre_extensions import none_throws, override
3738

38-
DEFAULT_MIN_SAVINGS_THRESHOLD = 0.01 # 1% threshold
39+
DEFAULT_MIN_SAVINGS_THRESHOLD = 0.1 # 10% threshold
3940
MAX_PENDING_TRIALS_DEFAULT = 5
4041
DEFAULT_EARLY_STOPPING_HEALTHCHECK_TITLE = "Early Stopping Healthcheck"
4142

@@ -92,7 +93,7 @@ def __init__(
9293
default early stopping strategy will only be used for
9394
single-objective unconstrained experiments.
9495
min_savings_threshold: Minimum savings threshold to suggest early
95-
stopping. Default is 0.01 (1% savings).
96+
stopping. Default is 0.1 (10% savings).
9697
max_pending_trials: Maximum number of pending trials for replay
9798
orchestrator. Default is 5.
9899
auto_early_stopping_config: A string for configuring automated early
@@ -396,22 +397,40 @@ def _report_early_stopping_nudge(
396397
self, experiment: Experiment
397398
) -> HealthcheckAnalysisCard:
398399
"""Check if early stopping should be suggested (nudge) by estimating
399-
hypothetical savings using replay logic."""
400-
# Get map metrics from the experiment
401-
# Note: validate_applicable_state already ensures map_metrics is non-empty
402-
map_metrics = self._get_map_metrics(experiment)
403-
404-
# Estimate hypothetical savings for compatible metrics using replay
405-
metric_to_savings = self._estimate_hypothetical_savings_with_replay(
406-
experiment=experiment, map_metrics=map_metrics
400+
hypothetical savings using replay logic.
401+
402+
Only applicable for single-objective unconstrained experiments where a
403+
default early stopping strategy is available.
404+
"""
405+
opt_config = none_throws(experiment.optimization_config)
406+
metric = next(iter(opt_config.objective.metrics))
407+
savings = estimate_hypothetical_early_stopping_savings(
408+
experiment=experiment,
409+
metric=metric,
410+
max_pending_trials=self.max_pending_trials,
407411
)
408412

409-
if not metric_to_savings:
410-
# No significant savings detected
413+
if savings is None:
414+
# savings is None when estimate_hypothetical_early_stopping_savings
415+
# cannot compute savings. This happens for:
416+
# - Multi-objective or constrained experiments (no default ESS)
417+
# - Experiments without MapMetric data
418+
# - Experiment replay failures
419+
problem_type = self._get_problem_type(experiment)
411420
return self._create_card(
412421
subtitle=(
413-
"Early stopping is not enabled. While this experiment has "
414-
"data with a progression ('step' column) we did not detect "
422+
f"Early stopping is not enabled. Automatic early stopping "
423+
f"savings estimation is not available for this experiment "
424+
f"({problem_type}). If you want to use early stopping, "
425+
f"please configure an early_stopping_strategy explicitly."
426+
),
427+
status=HealthcheckStatus.PASS,
428+
)
429+
430+
if savings < self.min_savings_threshold:
431+
return self._create_card(
432+
subtitle=(
433+
"Early stopping is not enabled. We did not detect "
415434
"significant potential savings at this time.\n\n"
416435
"This could be because:\n"
417436
"- The experiment hasn't run enough trials yet\n"
@@ -423,38 +442,35 @@ def _report_early_stopping_nudge(
423442
)
424443

425444
# Found significant potential savings - nudge the user
426-
best_metric_name = max(metric_to_savings, key=metric_to_savings.get)
427-
best_savings = metric_to_savings[best_metric_name]
445+
savings_pct = 100 * savings
428446

429447
subtitle = EARLY_STOPPING_NUDGE_MSG.format(
430-
metric_name=best_metric_name, savings=best_savings
448+
metric_name=metric.name, savings=savings_pct
431449
)
432450

433451
# Append additional info if provided
434452
if self.nudge_additional_info:
435453
subtitle += f" {self.nudge_additional_info}"
436454

437455
# Create detailed metrics table
438-
metric_rows = [
439-
{
440-
"Metric Name": metric_name,
441-
"Estimated Savings": f"{savings:.1f}%",
442-
}
443-
for metric_name, savings in sorted(
444-
metric_to_savings.items(), key=lambda x: x[1], reverse=True
445-
)
446-
]
447-
df = pd.DataFrame(metric_rows)
456+
df = pd.DataFrame(
457+
[
458+
{
459+
"Metric Name": metric.name,
460+
"Estimated Savings": f"{savings_pct:.1f}%",
461+
}
462+
]
463+
)
448464

449-
title = EARLY_STOPPING_NUDGE_TITLE.format(savings=best_savings)
465+
title = EARLY_STOPPING_NUDGE_TITLE.format(savings=savings_pct)
450466

451467
return self._create_card(
452468
title=title,
453469
subtitle=subtitle,
454470
df=df,
455471
status=HealthcheckStatus.WARNING,
456-
potential_savings=best_savings,
457-
best_metric=best_metric_name,
472+
potential_savings=savings_pct,
473+
best_metric=metric.name,
458474
)
459475

460476
def _get_problem_type(self, experiment: Experiment) -> str:
@@ -485,63 +501,3 @@ def _get_map_metrics(self, experiment: Experiment) -> list[MapMetric]:
485501
reverse=True,
486502
)
487503
return map_metrics
488-
489-
def _estimate_hypothetical_savings_with_replay(
490-
self, experiment: Experiment, map_metrics: list[MapMetric]
491-
) -> dict[str, float]:
492-
"""
493-
Estimate hypothetical early stopping savings for each map metric using
494-
replay infrastructure.
495-
496-
This is the accurate method that replays the experiment with early stopping
497-
enabled to calculate actual savings.
498-
499-
Args:
500-
experiment: The experiment to analyze
501-
map_metrics: List of MapMetrics to analyze
502-
503-
Returns:
504-
Dictionary mapping metric names to estimated savings percentages
505-
(only includes metrics where savings > min_savings_threshold)
506-
"""
507-
metric_to_savings: dict[str, float] = {}
508-
509-
MAX_REPLAYS = 3
510-
MAX_REPLAY_TRIALS = 50
511-
REPLAY_NUM_POINTS_PER_CURVE = 20
512-
REPLAY_PERCENTILE_THRESHOLD = 65
513-
REPLAY_MIN_PROGRESSION_FRAC = 0.4
514-
REPLAY_MIN_CURVES = 5
515-
516-
# Limit to first few metrics to avoid expensive computation
517-
for map_metric in map_metrics[:MAX_REPLAYS]:
518-
try:
519-
# Create replayed experiment with early stopping
520-
replayed_experiment = replay_experiment(
521-
historical_experiment=experiment,
522-
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
523-
max_replay_trials=MAX_REPLAY_TRIALS,
524-
metric=map_metric,
525-
max_pending_trials=self.max_pending_trials,
526-
early_stopping_strategy=PercentileEarlyStoppingStrategy(
527-
min_curves=REPLAY_MIN_CURVES,
528-
min_progression=REPLAY_MIN_PROGRESSION_FRAC,
529-
percentile_threshold=REPLAY_PERCENTILE_THRESHOLD,
530-
normalize_progressions=True,
531-
),
532-
)
533-
534-
if replayed_experiment is not None:
535-
savings = estimate_early_stopping_savings(
536-
experiment=replayed_experiment
537-
)
538-
539-
# Only include if savings exceed threshold (> 1%)
540-
if savings > self.min_savings_threshold:
541-
metric_to_savings[map_metric.name] = 100 * savings
542-
543-
except Exception:
544-
# Skip metrics that fail replay
545-
continue
546-
547-
return metric_to_savings

ax/analysis/healthcheck/tests/test_early_stopping_healthcheck.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,38 @@ def test_early_stopping_not_enabled(self) -> None:
139139
"""Test behavior when early stopping is not enabled."""
140140
healthcheck = EarlyStoppingAnalysis(early_stopping_strategy=None)
141141

142-
with self.subTest("no_savings_detected"):
143-
card = healthcheck.compute(experiment=self.experiment)
142+
with self.subTest("no_savings_available"):
143+
# Mock estimate_hypothetical_early_stopping_savings to return None
144+
# This happens for MOO/constrained experiments, non-MapMetric data,
145+
# or replay failures
146+
with patch(
147+
"ax.analysis.healthcheck.early_stopping_healthcheck"
148+
".estimate_hypothetical_early_stopping_savings",
149+
return_value=None,
150+
):
151+
card = healthcheck.compute(experiment=self.experiment)
152+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
153+
self.assertIn("Early stopping is not enabled", card.subtitle)
154+
self.assertIn("Automatic early stopping savings estimation", card.subtitle)
155+
156+
with self.subTest("low_savings_detected"):
157+
# Mock low savings below threshold (default 10%)
158+
mock_savings = 0.05 # 5% savings
159+
with patch(
160+
"ax.analysis.healthcheck.early_stopping_healthcheck"
161+
".estimate_hypothetical_early_stopping_savings",
162+
return_value=mock_savings,
163+
):
164+
card = healthcheck.compute(experiment=self.experiment)
165+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
144166
self.assertIn("Early stopping is not enabled", card.subtitle)
167+
self.assertIn("did not detect significant potential savings", card.subtitle)
145168

146169
with self.subTest("potential_savings_detected"):
147-
mock_savings = {"ax_test_metric": 25.0}
148-
with patch.object(
149-
healthcheck,
150-
"_estimate_hypothetical_savings_with_replay",
170+
mock_savings = 0.25 # 25% as a decimal
171+
with patch(
172+
"ax.analysis.healthcheck.early_stopping_healthcheck"
173+
".estimate_hypothetical_early_stopping_savings",
151174
return_value=mock_savings,
152175
):
153176
card = healthcheck.compute(experiment=self.experiment)
@@ -324,11 +347,8 @@ def test_hypothetical_savings_nudge(self) -> None:
324347

325348
with self.subTest("basic_nudge"):
326349
with patch(
327-
"ax.analysis.healthcheck.early_stopping_healthcheck.replay_experiment",
328-
return_value=object(),
329-
), patch(
330350
"ax.analysis.healthcheck.early_stopping_healthcheck"
331-
".estimate_early_stopping_savings",
351+
".estimate_hypothetical_early_stopping_savings",
332352
return_value=0.25,
333353
):
334354
card = healthcheck.compute(experiment=self.experiment)
@@ -345,10 +365,10 @@ def test_hypothetical_savings_nudge(self) -> None:
345365
early_stopping_strategy=None, nudge_additional_info=nudge_info
346366
)
347367

348-
mock_savings = {"ax_test_metric": 25.0}
349-
with patch.object(
350-
healthcheck_with_info,
351-
"_estimate_hypothetical_savings_with_replay",
368+
mock_savings = 0.25
369+
with patch(
370+
"ax.analysis.healthcheck.early_stopping_healthcheck"
371+
".estimate_hypothetical_early_stopping_savings",
352372
return_value=mock_savings,
353373
):
354374
card = healthcheck_with_info.compute(experiment=self.experiment)

0 commit comments

Comments
 (0)