Skip to content

Commit e23bcf2

Browse files
lukmazThe Meridian Authors
authored andcommitted
Implement health score computation
PiperOrigin-RevId: 871462554
1 parent c338825 commit e23bcf2

File tree

5 files changed

+474
-83
lines changed

5 files changed

+474
-83
lines changed

meridian/analysis/review/constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,13 @@
4242
Q1 = "q1"
4343
Q3 = "q3"
4444
BAYESIAN_PPP = "bayesian_ppp"
45+
46+
# Health score constants
47+
R2_MIDPOINT = 0.5
48+
R2_STEEPNESS = 15
49+
FAIL_RATIO_POWER = 0.3
50+
HEALTH_SCORE_WEIGHT_BASELINE = 0.3
51+
HEALTH_SCORE_WEIGHT_BAYESIAN_PPP = 0.3
52+
HEALTH_SCORE_WEIGHT_GOF = 0.1
53+
HEALTH_SCORE_WEIGHT_PRIOR_POSTERIOR_SHIFT = 0.15
54+
HEALTH_SCORE_WEIGHT_ROI_CONSISTENCY = 0.15

meridian/analysis/review/results.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,11 +600,13 @@ class ReviewSummary:
600600
overall_status: The overall status of all checks.
601601
summary_message: A summary message of all checks.
602602
results: A list of all check results.
603+
health_score: The health score of the model.
603604
"""
604605

605606
overall_status: Status
606607
summary_message: str
607608
results: list[CheckResult]
609+
health_score: float
608610

609611
def __repr__(self) -> str:
610612
report = []
@@ -613,6 +615,7 @@ def __repr__(self) -> str:
613615
report.append("=" * 40)
614616
report.append(f"Overall Status: {self.overall_status.name}")
615617
report.append(f"Summary: {self.summary_message}")
618+
report.append(f"Health Score: {self.health_score:.1f}")
616619
report.append("\nCheck Results:")
617620

618621
for result in self.results:

meridian/analysis/review/results_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,35 @@ def test_prior_posterior_shift_result_recommendation(
370370
self.assertEqual(result.recommendation, expected_recommendation)
371371

372372

373+
class ReviewSummaryTest(parameterized.TestCase):
374+
375+
def test_review_summary_repr(self):
376+
mock_result = results.ConvergenceCheckResult(
377+
case=results.ConvergenceCases.CONVERGED,
378+
config=configs.ConvergenceConfig(),
379+
max_rhat=1.0,
380+
max_parameter="mock_var",
381+
)
382+
summary = results.ReviewSummary(
383+
overall_status=results.Status.PASS,
384+
summary_message="summary",
385+
results=[mock_result],
386+
health_score=95.2,
387+
)
388+
expected_repr = """========================================
389+
Model Quality Checks
390+
========================================
391+
Overall Status: PASS
392+
Summary: summary
393+
Health Score: 95.2
394+
395+
Check Results:
396+
----------------------------------------
397+
Convergence Check:
398+
Status: PASS
399+
Recommendation: The model has likely converged, as all parameters have R-hat values < 1.2."""
400+
self.assertMultiLineEqual(str(summary), expected_repr)
401+
402+
373403
if __name__ == "__main__":
374404
absltest.main()

meridian/analysis/review/reviewer.py

Lines changed: 189 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@
1414

1515
"""Implementation of the runner of the Model Quality Checks."""
1616

17+
from collections.abc import MutableMapping
18+
import dataclasses
1719
import typing
1820

1921
import immutabledict
2022
from meridian import constants
2123
from meridian.analysis import analyzer as analyzer_module
2224
from meridian.analysis.review import checks
2325
from meridian.analysis.review import configs
26+
from meridian.analysis.review import constants as review_constants
2427
from meridian.analysis.review import results
2528
from meridian.model import prior_distribution
26-
29+
import numpy as np
2730

2831
CheckType = typing.Type[checks.BaseCheck]
2932
ConfigInstance = configs.BaseConfig
@@ -38,6 +41,141 @@
3841
})
3942

4043

44+
def _get_baseline_score(
45+
baseline_check_result: results.BaselineCheckResult,
46+
) -> float:
47+
"""Returns the score of the Baseline check."""
48+
negative_baseline_prob = baseline_check_result.negative_baseline_prob
49+
baseline_config = baseline_check_result.config
50+
review_threshold = baseline_config.negative_baseline_prob_review_threshold
51+
fail_threshold = baseline_config.negative_baseline_prob_fail_threshold
52+
53+
return 100.0 * (
54+
1.0
55+
- np.clip(
56+
(negative_baseline_prob - review_threshold)
57+
/ (fail_threshold - review_threshold),
58+
0,
59+
1,
60+
)
61+
)
62+
63+
64+
def _get_bayesian_ppp_score(
65+
bayesian_ppp_check_result: results.BayesianPPPCheckResult,
66+
) -> float:
67+
"""Returns the score of the Bayesian PPP check."""
68+
bayesian_ppp = bayesian_ppp_check_result.bayesian_ppp
69+
bayesian_ppp_config = bayesian_ppp_check_result.config
70+
ppp_threshold = bayesian_ppp_config.ppp_threshold
71+
return 100.0 if bayesian_ppp > ppp_threshold else 0.0
72+
73+
74+
def _get_gof_score(
75+
goodness_of_fit_check_result: results.GoodnessOfFitCheckResult,
76+
) -> float:
77+
"""Returns the score of the Goodness of Fit check."""
78+
r_squared = goodness_of_fit_check_result.metrics.r_squared
79+
return 100.0 / (
80+
1
81+
+ np.exp(
82+
-review_constants.R2_STEEPNESS
83+
* (r_squared - review_constants.R2_MIDPOINT)
84+
)
85+
)
86+
87+
88+
def _get_pps_score(
89+
prior_posterior_shift_check_result: results.PriorPosteriorShiftCheckResult,
90+
) -> float:
91+
"""Returns the score of the Prior-Posterior Shift check."""
92+
prior_posterior_shift_ratio = len(
93+
prior_posterior_shift_check_result.no_shift_channels
94+
) / len(prior_posterior_shift_check_result.channel_results)
95+
return (
96+
100.0
97+
* (1.0 - np.clip(prior_posterior_shift_ratio, 0, 1))
98+
** review_constants.FAIL_RATIO_POWER
99+
)
100+
101+
102+
def _get_roi_consistency_score(
103+
roi_consistency_check_result: results.ROIConsistencyCheckResult,
104+
) -> float:
105+
"""Returns the score of the ROI Consistency check."""
106+
roi_consistency_failure_ratio = sum(
107+
1
108+
for r in roi_consistency_check_result.channel_results
109+
if r.case.status != results.Status.PASS
110+
) / len(roi_consistency_check_result.channel_results)
111+
return (
112+
100.0
113+
* (1.0 - np.clip(roi_consistency_failure_ratio, 0, 1))
114+
** review_constants.FAIL_RATIO_POWER
115+
)
116+
117+
118+
@dataclasses.dataclass(frozen=True)
119+
class _HealthScoreComponent:
120+
"""A component used in the calculation of the overall health score.
121+
122+
Attributes:
123+
check_type: The class of the check this component represents.
124+
score_function: A callable that takes the check result and returns a float
125+
score.
126+
result_type: The expected type of the result object for this check.
127+
weight: The weight of this component in the overall health score
128+
calculation.
129+
is_required: Whether this check is required to be present for the health
130+
score to be computed.
131+
"""
132+
133+
check_type: CheckType
134+
score_function: typing.Callable[[typing.Any], float]
135+
result_type: typing.Type[results.CheckResult]
136+
weight: float
137+
is_required: bool
138+
139+
140+
_HEALTH_SCORE_COMPONENTS = (
141+
_HealthScoreComponent(
142+
check_type=checks.BaselineCheck,
143+
score_function=_get_baseline_score,
144+
result_type=results.BaselineCheckResult,
145+
weight=review_constants.HEALTH_SCORE_WEIGHT_BASELINE,
146+
is_required=True,
147+
),
148+
_HealthScoreComponent(
149+
check_type=checks.BayesianPPPCheck,
150+
score_function=_get_bayesian_ppp_score,
151+
result_type=results.BayesianPPPCheckResult,
152+
weight=review_constants.HEALTH_SCORE_WEIGHT_BAYESIAN_PPP,
153+
is_required=True,
154+
),
155+
_HealthScoreComponent(
156+
check_type=checks.GoodnessOfFitCheck,
157+
score_function=_get_gof_score,
158+
result_type=results.GoodnessOfFitCheckResult,
159+
weight=review_constants.HEALTH_SCORE_WEIGHT_GOF,
160+
is_required=True,
161+
),
162+
_HealthScoreComponent(
163+
check_type=checks.PriorPosteriorShiftCheck,
164+
score_function=_get_pps_score,
165+
result_type=results.PriorPosteriorShiftCheckResult,
166+
weight=review_constants.HEALTH_SCORE_WEIGHT_PRIOR_POSTERIOR_SHIFT,
167+
is_required=False,
168+
),
169+
_HealthScoreComponent(
170+
check_type=checks.ROIConsistencyCheck,
171+
score_function=_get_roi_consistency_score,
172+
result_type=results.ROIConsistencyCheckResult,
173+
weight=review_constants.HEALTH_SCORE_WEIGHT_ROI_CONSISTENCY,
174+
is_required=False,
175+
),
176+
)
177+
178+
41179
class ModelReviewer:
42180
"""A tool for executing a series of quality checks on a Meridian model.
43181
@@ -57,15 +195,15 @@ def __init__(
57195
meridian,
58196
):
59197
self._meridian = meridian
60-
self._results: list[results.CheckResult] = []
198+
self._results: MutableMapping[CheckType, results.CheckResult] = {}
61199
self._analyzer = analyzer_module.Analyzer(
62200
model_context=meridian.model_context,
63201
inference_data=meridian.inference_data,
64202
)
65203

66-
def _run_and_handle(self, check_class, config):
67-
instance = check_class(self._meridian, self._analyzer, config) # pytype: disable=not-instantiable
68-
self._results.append(instance.run())
204+
def _run_and_handle(self, check_class: CheckType, config: configs.BaseConfig):
205+
instance: checks.BaseCheck = check_class(self._meridian, self._analyzer, config) # pytype: disable=not-instantiable
206+
self._results[check_class] = instance.run()
69207

70208
def _uses_roi_priors(self):
71209
"""Checks if the model uses ROI priors."""
@@ -102,22 +240,60 @@ def _has_custom_roi_priors(self):
102240
return True
103241
return False
104242

243+
def _compute_health_score(self) -> float:
244+
"""Computes the health score of the model.
245+
246+
Raises:
247+
ValueError: If any required checks are missing from the results.
248+
249+
Returns:
250+
The computed health score.
251+
"""
252+
missing_checks = [
253+
comp.check_type.__name__
254+
for comp in _HEALTH_SCORE_COMPONENTS
255+
if comp.is_required and comp.check_type not in self._results
256+
]
257+
if missing_checks:
258+
raise ValueError(
259+
"The following required checks results are missing:"
260+
f" {missing_checks}."
261+
)
262+
263+
scores_and_weights = [
264+
(
265+
comp.score_function(
266+
typing.cast(comp.result_type, self._results[comp.check_type])
267+
),
268+
comp.weight,
269+
)
270+
for comp in _HEALTH_SCORE_COMPONENTS
271+
if comp.check_type in self._results
272+
]
273+
274+
sum_score = sum(score * weight for score, weight in scores_and_weights)
275+
total_weight = sum(weight for _, weight in scores_and_weights)
276+
277+
return sum_score / total_weight if total_weight else 0.0
278+
105279
def run(self) -> results.ReviewSummary:
106280
"""Executes all checks and generates the final summary."""
107-
self._results.clear()
281+
self._results = {}
108282
self._run_and_handle(checks.ConvergenceCheck, configs.ConvergenceConfig())
109283

110284
# Stop if the model did not converge.
111285
if (
112286
self._results
113-
and self._results[0].case is results.ConvergenceCases.NOT_CONVERGED
287+
and self._results[checks.ConvergenceCheck].case
288+
is results.ConvergenceCases.NOT_CONVERGED
114289
):
115290
return results.ReviewSummary(
116291
overall_status=results.Status.FAIL,
117292
summary_message=(
118293
"Failed: Model did not converge. Other checks were skipped."
119294
),
120-
results=self._results,
295+
results=list(self._results.values()),
296+
health_score=0.0,
121297
)
122298

123299
# Run all other checks in sequence.
@@ -138,10 +314,11 @@ def run(self) -> results.ReviewSummary:
138314

139315
# Determine the final overall status.
140316
has_failures = any(
141-
res.case.status is results.Status.FAIL for res in self._results
317+
res.case.status is results.Status.FAIL for res in self._results.values()
142318
)
143319
has_reviews = any(
144-
res.case.status is results.Status.REVIEW for res in self._results
320+
res.case.status is results.Status.REVIEW
321+
for res in self._results.values()
145322
)
146323

147324
if has_failures and has_reviews:
@@ -167,5 +344,6 @@ def run(self) -> results.ReviewSummary:
167344
return results.ReviewSummary(
168345
overall_status=overall_status,
169346
summary_message=summary_message,
170-
results=self._results,
347+
results=list(self._results.values()),
348+
health_score=self._compute_health_score(),
171349
)

0 commit comments

Comments
 (0)