Skip to content

Commit 9a00edc

Browse files
shrutipatel31meta-codesync[bot]
authored andcommitted
Filter failed trials from plots (#4725)
Summary: Pull Request resolved: #4725 This diff updates the default trial status filtering in Ax analysis plots to also exclude `FAILED` trials. Previously, the analysis plots (`ArmEffectsPlot`, `ObjectivePFeasibleFrontierPlot`, `ScatterPlot`) excluded only `ABANDONED` and `STALE` trials by default. This change adds `FAILED` to the exclusion list because failed trials typically don't have valid observation data and shouldn't be included in visualizations. Updated `get_trial_statuses_with_fallback()` in `utils.py` to exclude `TrialStatus.FAILED` from the default set Reviewed By: bernardbeckerman Differential Revision: D89913203 Privacy Context Container: L1307644 fbshipit-source-id: 7c67b7a8b448e09f7b3fb7ce293648235df92b6a
1 parent fd38b65 commit 9a00edc

File tree

9 files changed

+68
-26
lines changed

9 files changed

+68
-26
lines changed

ax/analysis/plotly/arm_effects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
against the status quo arm from the same trial.
9292
trial_index: If present, only use arms from the trial with the given index.
9393
trial_statuses: If present, only use arms from trials with the given
94-
statuses. By default, exclude STALE and ABANDONED trials.
94+
statuses. By default, exclude STALE, ABANDONED, and FAILED trials.
9595
additional_arms: If present, include these arms in the plot in addition to
9696
the arms in the experiment. These arms will be marked as belonging to a
9797
trial with index -1.

ax/analysis/plotly/objective_p_feasible_frontier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
label: A label to use in the plot in place of the metric name.
8383
trial_index: If present, only use arms from the trial with the given index.
8484
trial_statuses: If present, only use arms from trials with the given
85-
statuses. By default, exclude STALE and ABANDONED trials.
85+
statuses. By default, exclude STALE, ABANDONED, and FAILED trials.
8686
num_points_to_generate: The number of points to generate on the frontier.
8787
Ideally this should be sufficiently large to provide a frontier with
8888
reasonably good coverage.

ax/analysis/plotly/scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
against the status quo arm from the same trial.
154154
trial_index: If present, only use arms from the trial with the given index.
155155
trial_statuses: If present, only use arms from trials with the given
156-
statuses. By default, exclude STALE and ABANDONED trials.
156+
statuses. By default, exclude STALE, FAILED and ABANDONED trials.
157157
additional_arms: If present, include these arms in the plot in addition to
158158
the arms in the experiment. These arms will be marked as belonging to a
159159
trial with index -1.

ax/analysis/plotly/tests/test_arm_effects.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ax.utils.common.testutils import TestCase
1919
from ax.utils.testing.core_stubs import (
2020
get_branin_experiment,
21+
get_non_failed_arm_names,
2122
get_offline_experiments,
2223
get_online_experiments,
2324
)
@@ -75,9 +76,13 @@ def setUp(self) -> None:
7576

7677
def test_trial_statuses_behavior(self) -> None:
7778
# When neither trial_statuses nor trial_index is provided,
78-
# should use default statuses (excluding ABANDONED and STALE)
79+
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
7980
analysis = ArmEffectsPlot(metric_name="foo")
80-
expected_statuses = {*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE}
81+
expected_statuses = {*TrialStatus} - {
82+
TrialStatus.ABANDONED,
83+
TrialStatus.STALE,
84+
TrialStatus.FAILED,
85+
}
8186
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
8287

8388
# When trial_statuses is explicitly provided, it should be used
@@ -129,9 +134,11 @@ def test_compute_raw(self) -> None:
129134
},
130135
)
131136

132-
# Check that we have one row per arm and that each arm appears only once
133-
self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name))
134-
for arm_name in self.client._experiment.arms_by_name:
137+
# Check that we have one row per arm from non-failed trials and that each
138+
# arm appears only once
139+
non_failed_arms = get_non_failed_arm_names(self.client._experiment)
140+
self.assertEqual(len(card.df), len(non_failed_arms))
141+
for arm_name in non_failed_arms:
135142
self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
136143

137144
# Check that all SEMs are NaN
@@ -158,9 +165,11 @@ def test_compute_with_modeled(self) -> None:
158165
},
159166
)
160167

161-
# Check that we have one row per arm and that each arm appears only once
162-
self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name))
163-
for arm_name in self.client._experiment.arms_by_name:
168+
# Check that we have one row per arm from non-failed trials and that each
169+
# arm appears only once
170+
non_failed_arms = get_non_failed_arm_names(self.client._experiment)
171+
self.assertEqual(len(card.df), len(non_failed_arms))
172+
for arm_name in non_failed_arms:
164173
self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
165174

166175
# Check that all SEMs are not NaN

ax/analysis/plotly/tests/test_objective_p_feasible_frontier.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,13 @@ def setUp(self) -> None:
5050

5151
def test_trial_statuses_behavior(self) -> None:
5252
# When neither trial_statuses nor trial_index is provided,
53-
# should use default statuses (excluding ABANDONED and STALE)
53+
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
5454
analysis = ObjectivePFeasibleFrontierPlot()
55-
expected_statuses = {*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE}
55+
expected_statuses = {*TrialStatus} - {
56+
TrialStatus.ABANDONED,
57+
TrialStatus.STALE,
58+
TrialStatus.FAILED,
59+
}
5660
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
5761

5862
# When trial_statuses is explicitly provided, it should be used

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from ax.core.trial_status import TrialStatus
1818
from ax.exceptions.core import UserInputError
1919
from ax.utils.common.testutils import TestCase
20-
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
20+
from ax.utils.testing.core_stubs import (
21+
get_non_failed_arm_names,
22+
get_offline_experiments,
23+
get_online_experiments,
24+
)
2125
from ax.utils.testing.mock import mock_botorch_optimize
2226
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node
2327
from pyre_extensions import assert_is_instance, none_throws
@@ -69,9 +73,13 @@ def setUp(self) -> None:
6973

7074
def test_trial_statuses_behavior(self) -> None:
7175
# When neither trial_statuses nor trial_index is provided,
72-
# should use default statuses (excluding ABANDONED and STALE)
76+
# should use default statuses (excluding ABANDONED, STALE, and FAILED)
7377
analysis = ScatterPlot(x_metric_name="foo", y_metric_name="bar")
74-
expected_statuses = {*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE}
78+
expected_statuses = {*TrialStatus} - {
79+
TrialStatus.ABANDONED,
80+
TrialStatus.STALE,
81+
TrialStatus.FAILED,
82+
}
7583
self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses)
7684

7785
# When trial_statuses is explicitly provided, it should be used
@@ -133,9 +141,11 @@ def test_compute_raw(self) -> None:
133141
)
134142
self.assertIsNotNone(card.blob)
135143

136-
# Check that we have one row per arm and that each arm appears only once
137-
self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name))
138-
for arm_name in self.client._experiment.arms_by_name:
144+
# Check that we have one row per arm from non-failed trials and that each
145+
# arm appears only once
146+
non_failed_arms = get_non_failed_arm_names(self.client._experiment)
147+
self.assertEqual(len(card.df), len(non_failed_arms))
148+
for arm_name in non_failed_arms:
139149
self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
140150

141151
# Check that all SEMs are NaN
@@ -191,9 +201,11 @@ def test_compute_with_modeled(self) -> None:
191201

192202
self.assertIsNotNone(card.blob)
193203

194-
# Check that we have one row per arm and that each arm appears only once
195-
self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name))
196-
for arm_name in self.client._experiment.arms_by_name:
204+
# Check that we have one row per arm from non-failed trials and that each
205+
# arm appears only once
206+
non_failed_arms = get_non_failed_arm_names(self.client._experiment)
207+
self.assertEqual(len(card.df), len(non_failed_arms))
208+
for arm_name in non_failed_arms:
197209
self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
198210

199211
# Check that all SEMs are not NaN

ax/analysis/plotly/tests/test_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,20 @@ def test_get_trial_statuses_with_fallback_with_trial_index(self) -> None:
3333

3434
def test_get_trial_statuses_with_fallback_default(self) -> None:
3535
# When neither trial_statuses nor trial_index is provided,
36-
# should return all statuses except ABANDONED and STALE
36+
# should return all statuses except ABANDONED, STALE, and FAILED
3737
result = none_throws(
3838
get_trial_statuses_with_fallback(trial_statuses=None, trial_index=None)
3939
)
4040

41-
expected_statuses = {*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE}
41+
expected_statuses = {*TrialStatus} - {
42+
TrialStatus.ABANDONED,
43+
TrialStatus.STALE,
44+
TrialStatus.FAILED,
45+
}
4246
self.assertEqual(set(result), expected_statuses)
4347
self.assertNotIn(TrialStatus.ABANDONED, result)
4448
self.assertNotIn(TrialStatus.STALE, result)
49+
self.assertNotIn(TrialStatus.FAILED, result)
4550

4651
def test_get_trial_statuses_with_fallback_explicit_takes_precedence(self) -> None:
4752
# When both trial_statuses and trial_index are provided,

ax/analysis/plotly/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,14 @@ def get_trial_statuses_with_fallback(
224224
) -> list[TrialStatus] | None:
225225
"""Get the default trial statuses to plot.
226226
227-
By default, include all trials except those that are abandoned or stale.
227+
By default, include all trials except those that are abandoned, stale, or failed.
228228
If trial_index is provided, then we only filter based on trial_index,
229229
and therefore this function returns None.
230230
"""
231231
if trial_index is not None:
232232
return None
233233
elif trial_statuses is not None:
234234
return [*trial_statuses]
235-
return [*{*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE}]
235+
return [
236+
*{*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE, TrialStatus.FAILED}
237+
]

ax/utils/testing/core_stubs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,16 @@ def add_arm(
18171817
return self
18181818

18191819

1820+
def get_non_failed_arm_names(experiment: Experiment) -> set[str]:
1821+
"""Get the names of all arms from non-failed trials."""
1822+
return {
1823+
arm.name
1824+
for trial in experiment.trials.values()
1825+
if trial.status != TrialStatus.FAILED
1826+
for arm in trial.arms
1827+
}
1828+
1829+
18201830
##############################
18211831
# Parameters
18221832
##############################

0 commit comments

Comments
 (0)