Skip to content

Commit ca36693

Browse files
shrutipatel31meta-codesync[bot]
authored andcommitted
Add NON_FAILED_STALE_ABANDONED_STATUSES to trial_statuses (facebook#4726)
Summary: Pull Request resolved: facebook#4726 This diff introduces a reusable constant `NON_FAILED_STALE_ABANDONED_STATUSES` in `trial_status.py` to replace the repeated pattern `{*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE, TrialStatus.FAILED}` used across multiple analysis files. Changes: - Added `NON_FAILED_STALE_ABANDONED_STATUSES` constant to `ax/core/trial_status.py` - Updated `get_trial_statuses_with_fallback()` in `utils.py` to use the new constant - Updated all test files to import and use the new constant Reviewed By: bernardbeckerman Differential Revision: D89928534 Privacy Context Container: L1307644 fbshipit-source-id: fcd41fea15621de7a9539895d5e9677ff51a28c8
1 parent 9a00edc commit ca36693

File tree

6 files changed

+25
-32
lines changed

6 files changed

+25
-32
lines changed

ax/analysis/plotly/tests/test_arm_effects.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ax.api.client import Client
1414
from ax.api.configs import RangeParameterConfig
1515
from ax.core.arm import Arm
16-
from ax.core.trial_status import TrialStatus
16+
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1717
from ax.exceptions.core import UserInputError
1818
from ax.utils.common.testutils import TestCase
1919
from ax.utils.testing.core_stubs import (
@@ -78,12 +78,10 @@ def test_trial_statuses_behavior(self) -> None:
7878
# When neither trial_statuses nor trial_index is provided,
7979
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
8080
analysis = ArmEffectsPlot(metric_name="foo")
81-
expected_statuses = {*TrialStatus} - {
82-
TrialStatus.ABANDONED,
83-
TrialStatus.STALE,
84-
TrialStatus.FAILED,
85-
}
86-
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
81+
self.assertEqual(
82+
set(none_throws(analysis.trial_statuses)),
83+
DEFAULT_ANALYSIS_STATUSES,
84+
)
8785

8886
# When trial_statuses is explicitly provided, it should be used
8987
explicit_statuses = [TrialStatus.COMPLETED, TrialStatus.RUNNING]

ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
OptimizationConfig,
2121
)
2222
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
23-
from ax.core.trial_status import TrialStatus
23+
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
2424
from ax.core.types import ComparisonOp
2525
from ax.utils.common.testutils import TestCase
2626
from ax.utils.testing.core_stubs import (
@@ -52,12 +52,10 @@ def test_trial_statuses_behavior(self) -> None:
5252
# When neither trial_statuses nor trial_index is provided,
5353
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
5454
analysis = ObjectivePFeasibleFrontierPlot()
55-
expected_statuses = {*TrialStatus} - {
56-
TrialStatus.ABANDONED,
57-
TrialStatus.STALE,
58-
TrialStatus.FAILED,
59-
}
60-
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
55+
self.assertEqual(
56+
set(none_throws(analysis.trial_statuses)),
57+
DEFAULT_ANALYSIS_STATUSES,
58+
)
6159

6260
# When trial_statuses is explicitly provided, it should be used
6361
explicit_statuses = [TrialStatus.COMPLETED, TrialStatus.RUNNING]

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ax.api.client import Client
1515
from ax.api.configs import RangeParameterConfig
1616
from ax.core.arm import Arm
17-
from ax.core.trial_status import TrialStatus
17+
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1818
from ax.exceptions.core import UserInputError
1919
from ax.utils.common.testutils import TestCase
2020
from ax.utils.testing.core_stubs import (
@@ -75,12 +75,10 @@ def test_trial_statuses_behavior(self) -> None:
7575
# When neither trial_statuses nor trial_index is provided,
7676
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
7777
analysis = ScatterPlot(x_metric_name="foo", y_metric_name="bar")
78-
expected_statuses = {*TrialStatus} - {
79-
TrialStatus.ABANDONED,
80-
TrialStatus.STALE,
81-
TrialStatus.FAILED,
82-
}
83-
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
78+
self.assertEqual(
79+
set(none_throws(analysis.trial_statuses)),
80+
DEFAULT_ANALYSIS_STATUSES,
81+
)
8482

8583
# When trial_statuses is explicitly provided, it should be used
8684
explicit_statuses = [TrialStatus.COMPLETED, TrialStatus.RUNNING]

ax/analysis/plotly/tests/test_utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_trial_statuses_with_fallback,
1212
trial_index_to_color,
1313
)
14-
from ax.core.trial_status import TrialStatus
14+
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1515
from ax.utils.common.testutils import TestCase
1616
from pyre_extensions import none_throws
1717

@@ -38,12 +38,7 @@ def test_get_trial_statuses_with_fallback_default(self) -> None:
3838
get_trial_statuses_with_fallback(trial_statuses=None, trial_index=None)
3939
)
4040

41-
expected_statuses = {*TrialStatus} - {
42-
TrialStatus.ABANDONED,
43-
TrialStatus.STALE,
44-
TrialStatus.FAILED,
45-
}
46-
self.assertEqual(set(result), expected_statuses)
41+
self.assertEqual(set(result), DEFAULT_ANALYSIS_STATUSES)
4742
self.assertNotIn(TrialStatus.ABANDONED, result)
4843
self.assertNotIn(TrialStatus.STALE, result)
4944
self.assertNotIn(TrialStatus.FAILED, result)

ax/analysis/plotly/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE, LIGHT_AX_BLUE
1414
from ax.core.experiment import Experiment
1515
from ax.core.objective import MultiObjective, ScalarizedObjective
16-
from ax.core.trial_status import TrialStatus
16+
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1717
from ax.exceptions.core import UnsupportedError
1818
from plotly import express as px
1919

@@ -232,6 +232,4 @@ def get_trial_statuses_with_fallback(
232232
return None
233233
elif trial_statuses is not None:
234234
return [*trial_statuses]
235-
return [
236-
*{*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE, TrialStatus.FAILED}
237-
]
235+
return [*DEFAULT_ANALYSIS_STATUSES]

ax/core/trial_status.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,9 @@ def __repr__(self) -> str:
173173
TrialStatus.STALE,
174174
TrialStatus.CANDIDATE,
175175
]
176+
177+
DEFAULT_ANALYSIS_STATUSES: set[TrialStatus] = set(TrialStatus) - {
178+
TrialStatus.ABANDONED,
179+
TrialStatus.STALE,
180+
TrialStatus.FAILED,
181+
}

0 commit comments

Comments
 (0)