Skip to content

Commit ad37fc1

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Show all healthcheck cards sorted by severity and priority
Summary: Previously, the overview analysis only displayed non-passing healthcheck cards. This change shows all healthcheck cards, sorted by severity and priority to surface the most important information first. The sorting order is: 1. ErrorAnalysisCard (errors during computation) 2. FAIL status 3. WARNING status 4. PASS status with priority (BaselineImprovementAnalysis, EarlyStoppingAnalysis - these provide valuable progress metrics even when passing) 5. PASS status (rest) This gives users visibility into the full health of their experiment while keeping critical issues at the top. Differential Revision: D91750384
1 parent 5190f4b commit ad37fc1

File tree

3 files changed

+112
-12
lines changed

3 files changed

+112
-12
lines changed

ax/analysis/healthcheck/healthcheck_analysis.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66
# pyre-strict
77

8+
from __future__ import annotations
9+
810
import json
911
from enum import IntEnum
1012

1113
import pandas as pd
12-
from ax.core.analysis_card import AnalysisCard
14+
from ax.analysis.analysis import ErrorAnalysisCard
15+
from ax.core.analysis_card import AnalysisCard, AnalysisCardBase
1316

1417

1518
class HealthcheckStatus(IntEnum):
@@ -18,6 +21,13 @@ class HealthcheckStatus(IntEnum):
1821
WARNING = 2
1922

2023

24+
# Healthchecks that provide valuable progress info even when passing
25+
PRIORITY_HEALTHCHECKS: set[str] = {
26+
"BaselineImprovementAnalysis",
27+
"EarlyStoppingAnalysis",
28+
}
29+
30+
2131
class HealthcheckAnalysisCard(AnalysisCard):
2232
def get_status(self) -> HealthcheckStatus:
2333
return HealthcheckStatus(json.loads(self.blob)["status"])
@@ -49,3 +59,49 @@ def create_healthcheck_analysis_card(
4959
}
5060
),
5161
)
62+
63+
64+
# Status order for sorting: FAIL first, then WARNING, then PASS
65+
_STATUS_SORT_ORDER: dict[HealthcheckStatus, int] = {
66+
HealthcheckStatus.FAIL: 1,
67+
HealthcheckStatus.WARNING: 2,
68+
HealthcheckStatus.PASS: 3,
69+
}
70+
71+
72+
def sort_healthcheck_cards(
73+
cards: list[AnalysisCardBase],
74+
) -> list[AnalysisCardBase]:
75+
"""
76+
Sort healthcheck cards by severity and priority.
77+
78+
Order:
79+
1. ErrorAnalysisCard (errors during computation)
80+
2. FAIL status
81+
3. WARNING status
82+
4. PASS status with priority (BaselineImprovement, EarlyStopping, etc.)
83+
5. PASS status (rest)
84+
85+
Args:
86+
cards: List of analysis cards (typically HealthcheckAnalysisCard or
87+
ErrorAnalysisCard instances).
88+
89+
Returns:
90+
Sorted list of cards.
91+
"""
92+
93+
def sort_key(card: AnalysisCardBase) -> tuple[int, int, str]:
94+
if isinstance(card, ErrorAnalysisCard):
95+
return (0, 0, card.name)
96+
97+
if isinstance(card, HealthcheckAnalysisCard):
98+
return (
99+
_STATUS_SORT_ORDER[card.get_status()],
100+
0 if card.name in PRIORITY_HEALTHCHECKS else 1,
101+
card.name,
102+
)
103+
104+
# Fallback for type safety (unreachable in practice)
105+
return (4, 1, card.name)
106+
107+
return sorted(cards, key=sort_key)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
import pandas as pd
9+
from ax.analysis.analysis import ErrorAnalysisCard
10+
from ax.analysis.healthcheck.healthcheck_analysis import (
11+
create_healthcheck_analysis_card,
12+
HealthcheckStatus,
13+
sort_healthcheck_cards,
14+
)
15+
from ax.core.analysis_card import AnalysisCardBase
16+
from ax.utils.common.testutils import TestCase
17+
18+
19+
def _card(name: str, status: HealthcheckStatus) -> AnalysisCardBase:
20+
return create_healthcheck_analysis_card(
21+
name=name, title=name, subtitle=name, df=pd.DataFrame(), status=status
22+
)
23+
24+
25+
def _error(name: str) -> AnalysisCardBase:
26+
return ErrorAnalysisCard(
27+
name=name, title=name, subtitle=name, df=pd.DataFrame(), blob=""
28+
)
29+
30+
31+
class TestHealthcheckAnalysis(TestCase):
32+
def test_sort_ordering(self) -> None:
33+
cards: list[AnalysisCardBase] = [
34+
_card("RegularAnalysis", HealthcheckStatus.PASS),
35+
_card("WarningAnalysis", HealthcheckStatus.WARNING),
36+
_error("ErrorAnalysis"),
37+
_card("BaselineImprovementAnalysis", HealthcheckStatus.PASS),
38+
_card("FailAnalysis", HealthcheckStatus.FAIL),
39+
]
40+
result = sort_healthcheck_cards(cards)
41+
42+
self.assertEqual(
43+
[c.name for c in result],
44+
[
45+
"ErrorAnalysis",
46+
"FailAnalysis",
47+
"WarningAnalysis",
48+
"BaselineImprovementAnalysis",
49+
"RegularAnalysis",
50+
],
51+
)

ax/analysis/overview.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, final
99

1010
from ax.adapter.base import Adapter
11-
from ax.analysis.analysis import Analysis, ErrorAnalysisCard
11+
from ax.analysis.analysis import Analysis
1212
from ax.analysis.diagnostics import DiagnosticAnalysis
1313
from ax.analysis.healthcheck.baseline_improvement import BaselineImprovementAnalysis
1414
from ax.analysis.healthcheck.can_generate_candidates import (
@@ -19,7 +19,7 @@
1919
ConstraintsFeasibilityAnalysis,
2020
)
2121
from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis
22-
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard
22+
from ax.analysis.healthcheck.healthcheck_analysis import sort_healthcheck_cards
2323
from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis
2424
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
2525
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
@@ -247,21 +247,14 @@ def compute(
247247
if analyis is not None
248248
]
249249

250-
non_passing_health_checks = [
251-
card
252-
for card in health_check_cards
253-
if (isinstance(card, HealthcheckAnalysisCard) and not card.is_passing())
254-
or isinstance(card, ErrorAnalysisCard)
255-
]
256-
257250
health_checks_group = (
258251
AnalysisCardGroup(
259252
name="HealthchecksAnalysis",
260253
title=HEALTH_CHECK_CARDGROUP_TITLE,
261254
subtitle=HEALTH_CHECK_CARDGROUP_SUBTITLE,
262-
children=non_passing_health_checks,
255+
children=sort_healthcheck_cards(health_check_cards),
263256
)
264-
if len(non_passing_health_checks) > 0
257+
if len(health_check_cards) > 0
265258
else None
266259
)
267260

0 commit comments

Comments
 (0)