|
17 | 17 | from ax.core.trial_status import TrialStatus |
18 | 18 | from ax.exceptions.core import UserInputError |
19 | 19 | from ax.utils.common.testutils import TestCase |
20 | | -from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments |
| 20 | +from ax.utils.testing.core_stubs import ( |
| 21 | + get_non_failed_arm_names, |
| 22 | + get_offline_experiments, |
| 23 | + get_online_experiments, |
| 24 | +) |
21 | 25 | from ax.utils.testing.mock import mock_botorch_optimize |
22 | 26 | from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node |
23 | 27 | from pyre_extensions import assert_is_instance, none_throws |
@@ -69,9 +73,13 @@ def setUp(self) -> None: |
69 | 73 |
|
70 | 74 | def test_trial_statuses_behavior(self) -> None: |
71 | 75 | # When neither trial_statuses nor trial_index is provided, |
72 | | - # should use default statuses (excluding ABANDONED and STALE) |
| 76 | + # should use default statuses (excluding ABANDONED, STALE, and FAILED) |
73 | 77 | analysis = ScatterPlot(x_metric_name="foo", y_metric_name="bar") |
74 | | - expected_statuses = {*TrialStatus} - {TrialStatus.ABANDONED, TrialStatus.STALE} |
| 78 | + expected_statuses = {*TrialStatus} - { |
| 79 | + TrialStatus.ABANDONED, |
| 80 | + TrialStatus.STALE, |
| 81 | + TrialStatus.FAILED, |
| 82 | + } |
75 | 83 | self.assertEqual(set(none_throws(analysis.trial_statuses)), expected_statuses) |
76 | 84 |
|
77 | 85 | # When trial_statuses is explicitly provided, it should be used |
@@ -133,9 +141,11 @@ def test_compute_raw(self) -> None: |
133 | 141 | ) |
134 | 142 | self.assertIsNotNone(card.blob) |
135 | 143 |
|
136 | | - # Check that we have one row per arm and that each arm appears only once |
137 | | - self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name)) |
138 | | - for arm_name in self.client._experiment.arms_by_name: |
| 144 | + # Check that we have one row per arm from non-failed trials and that each |
| 145 | + # arm appears only once |
| 146 | + non_failed_arms = get_non_failed_arm_names(self.client._experiment) |
| 147 | + self.assertEqual(len(card.df), len(non_failed_arms)) |
| 148 | + for arm_name in non_failed_arms: |
139 | 149 | self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1) |
140 | 150 |
|
141 | 151 | # Check that all SEMs are NaN |
@@ -191,9 +201,11 @@ def test_compute_with_modeled(self) -> None: |
191 | 201 |
|
192 | 202 | self.assertIsNotNone(card.blob) |
193 | 203 |
|
194 | | - # Check that we have one row per arm and that each arm appears only once |
195 | | - self.assertEqual(len(card.df), len(self.client._experiment.arms_by_name)) |
196 | | - for arm_name in self.client._experiment.arms_by_name: |
| 204 | + # Check that we have one row per arm from non-failed trials and that each |
| 205 | + # arm appears only once |
| 206 | + non_failed_arms = get_non_failed_arm_names(self.client._experiment) |
| 207 | + self.assertEqual(len(card.df), len(non_failed_arms)) |
| 208 | + for arm_name in non_failed_arms: |
197 | 209 | self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1) |
198 | 210 |
|
199 | 211 | # Check that all SEMs are not NaN |
|
0 commit comments