Skip to content

Commit 61ad6d6

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
(5/6) Port helpers to OSS for the new Complexity Rating Healthcheck - Renaming and rewording for OSS (#4653)
Summary: Pull Request resolved: #4653 Reviewed By: bernardbeckerman Differential Revision: D88902863
1 parent 1820855 commit 61ad6d6

File tree

2 files changed

+60
-54
lines changed

2 files changed

+60
-54
lines changed

ax/utils/common/complexity_utils.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
WHEELHOUSE_TIER_MESSAGE = """This experiment is in tier 'Wheelhouse'.
1919
20-
Experiments belonging to this tier should not run into any problems! If an issue \
21-
does occur, please post to our github issues page.
20+
Experiments in the 'Wheelhouse' tier use standard features that are \
21+
thoroughly tested and should work reliably. If you encounter any issues, \
22+
they are typically easy to diagnose and resolve. Otherwise, please post \
23+
to our github issues page for help.
2224
"""
2325

2426
ADVANCED_TIER_MESSAGE = """This experiment is in tier 'Advanced'.
@@ -33,25 +35,26 @@
3335
UNSUPPORTED_TIER_MESSAGE = """This experiment is in tier 'Unsupported'.
3436
3537
You are pushing Ax beyond its limits. Please post to our github issues page for help \
36-
in improving/simplifying your configuration to conform to a more \
37-
well-supported usage tier if possible.
38+
with improving/simplifying your configuration to conform to a more \
39+
well-supported usage tier if possible. We strongly recommend simplifying your \
40+
configuration to fall within a supported tier.
3841
"""
3942

40-
WIKI_TIER_MESSAGE = "https://ax.dev/docs/why-ax"
41-
4243
UNKNOWN_TIER_MESSAGE = """Failed to determine the tier of this experiment.
4344
4445
Please post on our github issues page or reach out to the Ax user group \
4546
to determine the support tier of your workflow.
47+
This may indicate an issue with your experiment configuration or an internal \
48+
error during tier classification. Please review your configuration for any \
49+
unusual settings, or consult our github issues page documentation for guidance.
4650
"""
4751

48-
NOT_STANDARD_API_MESSAGE = (
49-
"The experiment summary indicates that this workflow is not using a standard \
50-
API (`uses_standard_api=False`). Tier classification works best when the full \
51-
experiment configuration is known upfront. If you are building a tool on top \
52-
of this function, ensure that `uses_standard_api` is set to `True` in the \
53-
`OptimizationSummary` when your tool uses a standard API."
54-
)
52+
NOT_STANDARD_API_MESSAGE = """The experiment summary indicates that this workflow \
53+
is not using a standard API (`uses_standard_api=False`). Tier classification works \
54+
best when the full experiment configuration is known upfront. If you are building a \
55+
tool on top of this function, ensure that `uses_standard_api` is set to `True` in the \
56+
`OptimizationSummary` when your tool uses a standard API.
57+
"""
5558

5659

5760
@dataclass(frozen=True)
@@ -68,11 +71,11 @@ class OptimizationSummary:
6871
num_outcome_constraints: Number of outcome constraints.
6972
uses_early_stopping: Whether early stopping is enabled.
7073
uses_global_stopping: Whether global stopping is enabled.
71-
all_inputs_are_configs: Whether all inputs are high-level configs
74+
uses_standard_api: Whether all inputs are high-level configs
7275
(as opposed to low-level Ax abstractions).
7376
7477
Optional Keys:
75-
max_trials: Maximum number of trials (required if all_inputs_are_configs
78+
max_trials: Maximum number of trials (required if uses_standard_api
7679
is True).
7780
tolerated_trial_failure_rate: Maximum tolerated trial failure rate
7881
(should be <= 0.9).
@@ -94,7 +97,7 @@ class OptimizationSummary:
9497
num_outcome_constraints: int
9598
uses_early_stopping: bool
9699
uses_global_stopping: bool
97-
all_inputs_are_configs: bool
100+
uses_standard_api: bool
98101
# Optional keys
99102
max_trials: int | None = None
100103
tolerated_trial_failure_rate: float | None = None
@@ -158,6 +161,11 @@ def summarize_ax_optimization_complexity(
158161
uses_merge_multiple_curves = True
159162
break
160163

164+
# Support both new key and old key for backward compatibility
165+
uses_standard_api = tier_metadata.get("uses_standard_api")
166+
if uses_standard_api is None:
167+
uses_standard_api = tier_metadata.get("all_inputs_are_configs", False)
168+
161169
return OptimizationSummary(
162170
max_trials=max_trials,
163171
num_params=num_params,
@@ -170,7 +178,7 @@ def summarize_ax_optimization_complexity(
170178
uses_early_stopping=uses_early_stopping,
171179
uses_global_stopping=uses_global_stopping,
172180
uses_merge_multiple_curves=uses_merge_multiple_curves,
173-
all_inputs_are_configs=tier_metadata.get("all_inputs_are_configs", False),
181+
uses_standard_api=uses_standard_api,
174182
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
175183
max_pending_trials=options.max_pending_trials,
176184
min_failed_trials_for_failure_rate_check=(
@@ -331,7 +339,7 @@ def _check_if_is_in_wheelhouse_other_settings(
331339
"""
332340
is_in_wheelhouse, is_supported = True, True
333341
max_trials = optimization_summary.max_trials
334-
if not optimization_summary.all_inputs_are_configs:
342+
if not optimization_summary.uses_standard_api:
335343
is_in_wheelhouse, is_supported = False, False
336344
why_not_supported.append(NOT_STANDARD_API_MESSAGE)
337345
elif max_trials is None:
@@ -417,7 +425,7 @@ def check_if_in_wheelhouse(
417425
num_categorical_6_inf, num_parameter_constraints
418426
- Optimization config: num_objectives, num_outcome_constraints
419427
- Other settings: max_trials, uses_early_stopping, uses_global_stopping,
420-
all_inputs_are_configs, tolerated_trial_failure_rate, max_pending_trials,
428+
uses_standard_api, tolerated_trial_failure_rate, max_pending_trials,
421429
min_failed_trials_for_failure_rate_check, non_default_advanced_options,
422430
uses_merge_multiple_curves
423431
@@ -516,9 +524,4 @@ def format_tier_message(
516524
f"\n{why_msg}\n"
517525
)
518526
msg += why_msg
519-
520-
msg += (
521-
"\n\nFor more information about the definition of each tier and what "
522-
f"level of support you can expect: {WIKI_TIER_MESSAGE}"
523-
)
524527
return msg

ax/utils/common/tests/test_complexity_utils.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_tier_metadata_extraction(self) -> None:
8686
test_cases = [
8787
(
8888
"with_values",
89-
{"user_supplied_max_trials": 50, "all_inputs_are_configs": True},
89+
{"user_supplied_max_trials": 50, "uses_standard_api": True},
9090
50,
9191
True,
9292
),
@@ -114,7 +114,7 @@ def test_tier_metadata_extraction(self) -> None:
114114

115115
# THEN the summary should reflect tier metadata values
116116
self.assertEqual(summary.max_trials, expected_max_trials)
117-
self.assertEqual(summary.all_inputs_are_configs, expected_all_configs)
117+
self.assertEqual(summary.uses_standard_api, expected_all_configs)
118118

119119
def test_orchestrator_options_extraction(self) -> None:
120120
# GIVEN custom orchestrator options
@@ -226,7 +226,7 @@ def test_unknown_tier_raises_error(self) -> None:
226226
)
227227

228228

229-
def get_experiment_summary(
229+
def get_optimization_summary(
230230
max_trials: int | None = 100,
231231
num_params: int = 10,
232232
num_binary: int = 0,
@@ -237,7 +237,7 @@ def get_experiment_summary(
237237
num_outcome_constraints: int = 0,
238238
uses_early_stopping: bool = False,
239239
uses_global_stopping: bool = False,
240-
all_inputs_are_configs: bool = True,
240+
uses_standard_api: bool = True,
241241
tolerated_trial_failure_rate: float | None = 0.5,
242242
max_pending_trials: int | None = 5,
243243
min_failed_trials_for_failure_rate_check: int | None = 5,
@@ -256,7 +256,7 @@ def get_experiment_summary(
256256
num_outcome_constraints=num_outcome_constraints,
257257
uses_early_stopping=uses_early_stopping,
258258
uses_global_stopping=uses_global_stopping,
259-
all_inputs_are_configs=all_inputs_are_configs,
259+
uses_standard_api=uses_standard_api,
260260
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
261261
max_pending_trials=max_pending_trials,
262262
min_failed_trials_for_failure_rate_check=(
@@ -272,7 +272,7 @@ class TestCheckIfInWheelhouse(TestCase):
272272

273273
def setUp(self) -> None:
274274
super().setUp()
275-
self.base_summary = get_experiment_summary()
275+
self.base_summary = get_optimization_summary()
276276

277277
def test_wheelhouse_tier_for_simple_experiment(self) -> None:
278278
"""Test that a simple experiment is classified as Wheelhouse tier."""
@@ -287,28 +287,28 @@ def test_wheelhouse_tier_for_simple_experiment(self) -> None:
287287
def test_advanced_tier_conditions(self) -> None:
288288
"""Test conditions that result in Advanced tier."""
289289
test_cases: list[tuple[OptimizationSummary, str]] = [
290-
(get_experiment_summary(max_trials=250), "250 total trials"),
291-
(get_experiment_summary(num_params=60), "60 tunable parameter(s)"),
292-
(get_experiment_summary(num_binary=75), "75 binary tunable parameter(s)"),
290+
(get_optimization_summary(max_trials=250), "250 total trials"),
291+
(get_optimization_summary(num_params=60), "60 tunable parameter(s)"),
292+
(get_optimization_summary(num_binary=75), "75 binary tunable parameter(s)"),
293293
(
294-
get_experiment_summary(num_categorical_3_5=1),
294+
get_optimization_summary(num_categorical_3_5=1),
295295
"1 unordered choice parameter(s)",
296296
),
297297
(
298-
get_experiment_summary(num_parameter_constraints=4),
298+
get_optimization_summary(num_parameter_constraints=4),
299299
"4 parameter constraints",
300300
),
301-
(get_experiment_summary(num_objectives=3), "3 objectives"),
301+
(get_optimization_summary(num_objectives=3), "3 objectives"),
302302
(
303-
get_experiment_summary(num_outcome_constraints=3),
303+
get_optimization_summary(num_outcome_constraints=3),
304304
"3 outcome constraints",
305305
),
306306
(
307-
get_experiment_summary(uses_early_stopping=True),
307+
get_optimization_summary(uses_early_stopping=True),
308308
"Early stopping is enabled",
309309
),
310310
(
311-
get_experiment_summary(uses_global_stopping=True),
311+
get_optimization_summary(uses_global_stopping=True),
312312
"Global stopping is enabled",
313313
),
314314
]
@@ -327,44 +327,47 @@ def test_advanced_tier_conditions(self) -> None:
327327
def test_unsupported_tier_conditions(self) -> None:
328328
"""Test conditions that result in Unsupported tier."""
329329
test_cases: list[tuple[OptimizationSummary, str]] = [
330-
(get_experiment_summary(max_trials=510), "510 total trials"),
331-
(get_experiment_summary(num_params=201), "201 tunable parameter(s)"),
332-
(get_experiment_summary(num_binary=101), "101 binary tunable parameter(s)"),
330+
(get_optimization_summary(max_trials=510), "510 total trials"),
331+
(get_optimization_summary(num_params=201), "201 tunable parameter(s)"),
333332
(
334-
get_experiment_summary(num_categorical_3_5=6),
333+
get_optimization_summary(num_binary=101),
334+
"101 binary tunable parameter(s)",
335+
),
336+
(
337+
get_optimization_summary(num_categorical_3_5=6),
335338
"unordered choice parameters with more than 3 options",
336339
),
337340
(
338-
get_experiment_summary(num_categorical_6_inf=2),
341+
get_optimization_summary(num_categorical_6_inf=2),
339342
"unordered choice parameters with more than 5 options",
340343
),
341344
(
342-
get_experiment_summary(num_parameter_constraints=6),
345+
get_optimization_summary(num_parameter_constraints=6),
343346
"6 parameter constraints",
344347
),
345-
(get_experiment_summary(num_objectives=5), "5 objectives"),
348+
(get_optimization_summary(num_objectives=5), "5 objectives"),
346349
(
347-
get_experiment_summary(num_outcome_constraints=6),
350+
get_optimization_summary(num_outcome_constraints=6),
348351
"6 outcome constraints",
349352
),
350353
(
351-
get_experiment_summary(all_inputs_are_configs=False),
354+
get_optimization_summary(uses_standard_api=False),
352355
"uses_standard_api=False",
353356
),
354357
(
355-
get_experiment_summary(tolerated_trial_failure_rate=0.99),
358+
get_optimization_summary(tolerated_trial_failure_rate=0.99),
356359
"tolerated_trial_failure_rate=0.99",
357360
),
358361
(
359-
get_experiment_summary(non_default_advanced_options=True),
362+
get_optimization_summary(non_default_advanced_options=True),
360363
"Non-default advanced_options",
361364
),
362365
(
363-
get_experiment_summary(uses_merge_multiple_curves=True),
366+
get_optimization_summary(uses_merge_multiple_curves=True),
364367
"merge_multiple_curves=True",
365368
),
366369
(
367-
get_experiment_summary(
370+
get_optimization_summary(
368371
max_pending_trials=3, min_failed_trials_for_failure_rate_check=7
369372
),
370373
"min_failed_trials_for_failure_rate_check=7",
@@ -380,8 +383,8 @@ def test_unsupported_tier_conditions(self) -> None:
380383
self.assertIn(expected_msg, why_not_supported[0])
381384

382385
def test_max_trials_none_raises(self) -> None:
383-
"""Test max_trials=None with all_inputs_are_configs=True raises error."""
384-
summary = get_experiment_summary(all_inputs_are_configs=True, max_trials=None)
386+
"""Test max_trials=None with uses_standard_api=True raises error."""
387+
summary = get_optimization_summary(uses_standard_api=True, max_trials=None)
385388

386389
with self.assertRaisesRegex(UserInputError, "`max_trials` should not be None!"):
387390
check_if_in_wheelhouse(summary)

0 commit comments

Comments
 (0)