Skip to content

Commit 4624c99

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Decouple summarize_ax_optimization_complexity from OrchestratorOptions (#4706)
Summary: This change decouples `summarize_ax_optimization_complexity` from requiring an `OrchestratorOptions` instance by accepting individual optional fields instead. Differential Revision: D89778530
1 parent 495906b commit 4624c99

File tree

4 files changed

+53
-42
lines changed

4 files changed

+53
-42
lines changed

ax/analysis/healthcheck/complexity_rating.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,14 @@ def compute(
138138
options = none_throws(self.options)
139139
optimization_summary = summarize_ax_optimization_complexity(
140140
experiment=experiment,
141-
options=options,
142141
tier_metadata=self.tier_metadata,
142+
early_stopping_strategy=options.early_stopping_strategy,
143+
global_stopping_strategy=options.global_stopping_strategy,
144+
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
145+
max_pending_trials=options.max_pending_trials,
146+
min_failed_trials_for_failure_rate_check=(
147+
options.min_failed_trials_for_failure_rate_check
148+
),
143149
)
144150

145151
# Determine tier

ax/service/utils/orchestrator_options.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
1414
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
1515

16+
# Default values for OrchestratorOptions fields
17+
DEFAULT_MAX_PENDING_TRIALS: int = 10
18+
DEFAULT_TOLERATED_TRIAL_FAILURE_RATE: float = 0.5
19+
DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK: int = 5
20+
1621

1722
class TrialType(Enum):
1823
TRIAL = 0
@@ -125,12 +130,14 @@ class OrchestratorOptions:
125130
Default to False.
126131
"""
127132

128-
max_pending_trials: int = 10
133+
max_pending_trials: int = DEFAULT_MAX_PENDING_TRIALS
129134
trial_type: TrialType = TrialType.TRIAL
130135
batch_size: int | None = None
131136
total_trials: int | None = None
132-
tolerated_trial_failure_rate: float = 0.5
133-
min_failed_trials_for_failure_rate_check: int = 5
137+
tolerated_trial_failure_rate: float = DEFAULT_TOLERATED_TRIAL_FAILURE_RATE
138+
min_failed_trials_for_failure_rate_check: int = (
139+
DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK
140+
)
134141
log_filepath: str | None = None
135142
logging_level: int = INFO
136143
ttl_seconds_for_trials: int | None = None

ax/utils/common/complexity_utils.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@
1212
from ax.adapter.adapter_utils import can_map_to_binary, is_unordered_choice
1313
from ax.core.experiment import Experiment
1414
from ax.core.objective import MultiObjective
15+
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
1516
from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError
16-
from ax.service.orchestrator import OrchestratorOptions
17+
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
18+
from ax.service.utils.orchestrator_options import (
19+
DEFAULT_MAX_PENDING_TRIALS,
20+
DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK,
21+
DEFAULT_TOLERATED_TRIAL_FAILURE_RATE,
22+
)
1723

1824
STANDARD_TIER_MESSAGE = """This experiment is in tier 'Standard'.
1925
@@ -141,8 +147,14 @@ class OptimizationSummary:
141147

142148
def summarize_ax_optimization_complexity(
143149
experiment: Experiment,
144-
options: OrchestratorOptions,
145150
tier_metadata: dict[str, Any],
151+
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
152+
global_stopping_strategy: BaseGlobalStoppingStrategy | None = None,
153+
tolerated_trial_failure_rate: float | None = DEFAULT_TOLERATED_TRIAL_FAILURE_RATE,
154+
max_pending_trials: int | None = DEFAULT_MAX_PENDING_TRIALS,
155+
min_failed_trials_for_failure_rate_check: int | None = (
156+
DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK
157+
),
146158
) -> OptimizationSummary:
147159
"""Summarize the experiment's optimization complexity.
148160
@@ -151,11 +163,25 @@ def summarize_ax_optimization_complexity(
151163
152164
Args:
153165
experiment: The Ax Experiment.
154-
options: The orchestrator options.
155-
tier_metadata: tier-related meta-data from the orchestrator.
166+
tier_metadata: Tier-related metadata from the orchestrator. Supported keys:
167+
- 'user_supplied_max_trials': Maximum number of trials.
168+
- 'uses_standard_api': Whether high-level configs are used (as
169+
opposed to low-level Ax abstractions), ensuring the full
170+
experiment configuration is known upfront.
171+
- 'all_inputs_are_configs': Deprecated alias for 'uses_standard_api',
172+
supported for backward compatibility.
173+
early_stopping_strategy: The early stopping strategy, if any. Used to
174+
determine if early stopping is enabled. Defaults to None.
175+
global_stopping_strategy: The global stopping strategy, if any. Used to
176+
determine if global stopping is enabled. Defaults to None.
177+
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
178+
the whole optimization ending. Defaults to 0.5.
179+
max_pending_trials: Maximum number of pending trials. Defaults to 10.
180+
min_failed_trials_for_failure_rate_check: Minimum failed trials before
181+
failure rate is checked. Defaults to 5.
156182
157183
Returns:
158-
A dictionary summarizing the experiment.
184+
An OptimizationSummary containing experiment complexity metrics.
159185
"""
160186
search_space = experiment.search_space
161187
optimization_config = experiment.optimization_config
@@ -179,8 +205,8 @@ def summarize_ax_optimization_complexity(
179205
else 1
180206
)
181207
num_outcome_constraints = len(optimization_config.outcome_constraints)
182-
uses_early_stopping = options.early_stopping_strategy is not None
183-
uses_global_stopping = options.global_stopping_strategy is not None
208+
uses_early_stopping = early_stopping_strategy is not None
209+
uses_global_stopping = global_stopping_strategy is not None
184210

185211
# Check if any metrics use merge_multiple_curves
186212
uses_merge_multiple_curves = False
@@ -210,10 +236,10 @@ def summarize_ax_optimization_complexity(
210236
uses_global_stopping=uses_global_stopping,
211237
uses_merge_multiple_curves=uses_merge_multiple_curves,
212238
uses_standard_api=uses_standard_api,
213-
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
214-
max_pending_trials=options.max_pending_trials,
239+
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
240+
max_pending_trials=max_pending_trials,
215241
min_failed_trials_for_failure_rate_check=(
216-
options.min_failed_trials_for_failure_rate_check
242+
min_failed_trials_for_failure_rate_check
217243
),
218244
)
219245

ax/utils/common/tests/test_complexity_utils.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from ax.core.metric import Metric
1010
from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError
11-
from ax.service.orchestrator import OrchestratorOptions
1211
from ax.utils.common.complexity_utils import (
1312
check_if_in_standard,
1413
DEFAULT_TIER_MESSAGES,
@@ -30,7 +29,6 @@ class TestSummarizeAxOptimizationComplexity(TestCase):
3029
def setUp(self) -> None:
3130
super().setUp()
3231
self.experiment = get_experiment()
33-
self.options = OrchestratorOptions()
3432
self.tier_metadata: dict[str, object] = {}
3533

3634
def test_basic_experiment_summary(self) -> None:
@@ -39,7 +37,6 @@ def test_basic_experiment_summary(self) -> None:
3937
# WHEN we summarize the experiment
4038
summary = summarize_ax_optimization_complexity(
4139
experiment=self.experiment,
42-
options=self.options,
4340
tier_metadata=self.tier_metadata,
4441
)
4542

@@ -58,7 +55,6 @@ def test_multi_objective_experiment(self) -> None:
5855
# WHEN we summarize the experiment
5956
summary = summarize_ax_optimization_complexity(
6057
experiment=experiment,
61-
options=self.options,
6258
tier_metadata=self.tier_metadata,
6359
)
6460

@@ -76,7 +72,6 @@ def test_experiment_without_optimization_config_raises(self) -> None:
7672
):
7773
summarize_ax_optimization_complexity(
7874
experiment=self.experiment,
79-
options=self.options,
8075
tier_metadata=self.tier_metadata,
8176
)
8277

@@ -107,42 +102,20 @@ def test_tier_metadata_extraction(self) -> None:
107102
# WHEN we summarize the experiment
108103
summary = summarize_ax_optimization_complexity(
109104
experiment=self.experiment,
110-
options=self.options,
111105
tier_metadata=tier_metadata,
112106
)
113107

114108
# THEN the summary should reflect tier metadata values
115109
self.assertEqual(summary.max_trials, expected_max_trials)
116110
self.assertEqual(summary.uses_standard_api, expected_all_configs)
117111

118-
def test_orchestrator_options_extraction(self) -> None:
119-
# GIVEN custom orchestrator options
120-
options = OrchestratorOptions(
121-
tolerated_trial_failure_rate=0.25,
122-
max_pending_trials=5,
123-
min_failed_trials_for_failure_rate_check=10,
124-
)
125-
126-
# WHEN we summarize the experiment
127-
summary = summarize_ax_optimization_complexity(
128-
experiment=self.experiment,
129-
options=options,
130-
tier_metadata=self.tier_metadata,
131-
)
132-
133-
# THEN the summary should reflect orchestrator options
134-
self.assertEqual(summary.tolerated_trial_failure_rate, 0.25)
135-
self.assertEqual(summary.max_pending_trials, 5)
136-
self.assertEqual(summary.min_failed_trials_for_failure_rate_check, 10)
137-
138112
def test_parameter_constraints_counted(self) -> None:
139113
# GIVEN an experiment with parameter constraints
140114
experiment = get_experiment(constrain_search_space=True)
141115

142116
# WHEN we summarize the experiment
143117
summary = summarize_ax_optimization_complexity(
144118
experiment=experiment,
145-
options=self.options,
146119
tier_metadata=self.tier_metadata,
147120
)
148121

@@ -159,7 +132,6 @@ def test_merge_multiple_curves_detection(self) -> None:
159132
# WHEN we summarize the experiment
160133
summary = summarize_ax_optimization_complexity(
161134
experiment=self.experiment,
162-
options=self.options,
163135
tier_metadata=self.tier_metadata,
164136
)
165137

0 commit comments

Comments
 (0)