1414
1515"""Implementation of the runner of the Model Quality Checks."""
1616
17+ from collections .abc import MutableMapping
18+ import dataclasses
1719import typing
1820
1921import immutabledict
2022from meridian import constants
2123from meridian .analysis import analyzer as analyzer_module
2224from meridian .analysis .review import checks
2325from meridian .analysis .review import configs
26+ from meridian .analysis .review import constants as review_constants
2427from meridian .analysis .review import results
2528from meridian .model import prior_distribution
26-
29+ import numpy as np
2730
2831CheckType = typing .Type [checks .BaseCheck ]
2932ConfigInstance = configs .BaseConfig
3841})
3942
4043
def _get_baseline_score(
    baseline_check_result: results.BaselineCheckResult,
) -> float:
  """Returns the score of the Baseline check.

  The score is a linear ramp in the negative-baseline probability: the
  distance of the probability from the review threshold is divided by the
  distance between the fail and review thresholds, clipped to [0, 1], and
  mapped so that a clipped value of 0 scores 100 and a clipped value of 1
  scores 0.

  Args:
    baseline_check_result: Result of the Baseline check. Its
      `negative_baseline_prob` is scored against the review and fail
      thresholds stored on its `config`.

  Returns:
    A score in the range [0, 100].
  """
  negative_baseline_prob = baseline_check_result.negative_baseline_prob
  baseline_config = baseline_check_result.config
  review_threshold = baseline_config.negative_baseline_prob_review_threshold
  fail_threshold = baseline_config.negative_baseline_prob_fail_threshold

  threshold_span = fail_threshold - review_threshold
  if threshold_span == 0:
    # Degenerate configuration: both thresholds coincide, so the linear ramp
    # below would divide by zero. Fall back to a step at the shared threshold.
    # NOTE(review): a probability exactly at the threshold is scored 100 here,
    # matching the ramp's value at the review threshold -- confirm this agrees
    # with how the check itself treats the boundary.
    return 100.0 if negative_baseline_prob <= review_threshold else 0.0

  return 100.0 * (
      1.0
      - np.clip(
          (negative_baseline_prob - review_threshold) / threshold_span,
          0,
          1,
      )
  )
62+
63+
def _get_bayesian_ppp_score(
    bayesian_ppp_check_result: results.BayesianPPPCheckResult,
) -> float:
  """Returns the all-or-nothing score of the Bayesian PPP check.

  Args:
    bayesian_ppp_check_result: Result of the Bayesian PPP check. Its
      `bayesian_ppp` value is compared with `config.ppp_threshold`.

  Returns:
    100.0 when the PPP is strictly greater than the threshold, 0.0 otherwise.
  """
  observed_ppp = bayesian_ppp_check_result.bayesian_ppp
  threshold = bayesian_ppp_check_result.config.ppp_threshold
  # Strict comparison: a PPP exactly equal to the threshold scores 0.
  if observed_ppp > threshold:
    return 100.0
  return 0.0
72+
73+
def _get_gof_score(
    goodness_of_fit_check_result: results.GoodnessOfFitCheckResult,
) -> float:
  """Returns the score of the Goodness of Fit check.

  The score is a logistic curve in R-squared, scaled to 100:
  `100 / (1 + exp(-R2_STEEPNESS * (r_squared - R2_MIDPOINT)))`. An R-squared
  equal to `R2_MIDPOINT` therefore scores exactly 50.

  Args:
    goodness_of_fit_check_result: Result of the Goodness of Fit check; only
      `metrics.r_squared` is read.

  Returns:
    The logistic score computed from R-squared.
  """
  r_squared = goodness_of_fit_check_result.metrics.r_squared
  exponent = -review_constants.R2_STEEPNESS * (
      r_squared - review_constants.R2_MIDPOINT
  )
  return 100.0 / (1 + np.exp(exponent))
86+
87+
def _get_pps_score(
    prior_posterior_shift_check_result: results.PriorPosteriorShiftCheckResult,
) -> float:
  """Returns the score of the Prior-Posterior Shift check.

  The score is `100 * (1 - ratio) ** FAIL_RATIO_POWER`, where `ratio` is the
  number of entries in `no_shift_channels` divided by the number of entries
  in `channel_results`, clipped to [0, 1].

  Args:
    prior_posterior_shift_check_result: Result of the Prior-Posterior Shift
      check, providing `no_shift_channels` and `channel_results`.

  Returns:
    The score computed from the no-shift ratio, or 100.0 when
    `channel_results` is empty.
  """
  num_channels = len(prior_posterior_shift_check_result.channel_results)
  if num_channels == 0:
    # With no per-channel results the ratio is undefined (0 / 0). Treat it as
    # "no channel flagged" rather than raising ZeroDivisionError.
    # NOTE(review): confirm that 100 is the desired score for an empty check.
    return 100.0

  no_shift_ratio = (
      len(prior_posterior_shift_check_result.no_shift_channels) / num_channels
  )
  return (
      100.0
      * (1.0 - np.clip(no_shift_ratio, 0, 1))
      ** review_constants.FAIL_RATIO_POWER
  )
100+
101+
def _get_roi_consistency_score(
    roi_consistency_check_result: results.ROIConsistencyCheckResult,
) -> float:
  """Returns the score of the ROI Consistency check.

  The score is `100 * (1 - ratio) ** FAIL_RATIO_POWER`, where `ratio` is the
  fraction of entries in `channel_results` whose case status is anything
  other than `Status.PASS`, clipped to [0, 1].

  Args:
    roi_consistency_check_result: Result of the ROI Consistency check,
      providing the per-channel `channel_results`.

  Returns:
    The score computed from the non-passing ratio, or 100.0 when
    `channel_results` is empty.
  """
  channel_results = roi_consistency_check_result.channel_results
  num_channels = len(channel_results)
  if num_channels == 0:
    # With no per-channel results the ratio is undefined (0 / 0). Treat it as
    # "no channel flagged" rather than raising ZeroDivisionError.
    # NOTE(review): confirm that 100 is the desired score for an empty check.
    return 100.0

  # Enum members are singletons, so compare by identity like the rest of this
  # module does for `results.Status`.
  num_not_passing = sum(
      1
      for channel_result in channel_results
      if channel_result.case.status is not results.Status.PASS
  )
  failure_ratio = num_not_passing / num_channels
  return (
      100.0
      * (1.0 - np.clip(failure_ratio, 0, 1))
      ** review_constants.FAIL_RATIO_POWER
  )
116+
117+
@dataclasses.dataclass(frozen=True)
class _HealthScoreComponent:
  """Describes how one check contributes to the overall health score.

  Instances are immutable. Each one ties a check class to the function that
  turns that check's result into a score, together with the weight the score
  carries in the weighted average.

  Attributes:
    check_type: The check class whose result this component scores.
    score_function: Callable that receives the check result and returns its
      score as a float.
    result_type: The result class expected for `check_type`.
    weight: Weight applied to this component's score in the overall health
      score.
    is_required: Whether the health score can only be computed when a result
      for `check_type` is present.
  """

  # Field order is part of the generated `__init__` signature; keep it stable.
  check_type: CheckType
  score_function: typing.Callable[[typing.Any], float]
  result_type: typing.Type[results.CheckResult]
  weight: float
  is_required: bool
138+
139+
# Checks that feed the overall health score, each paired with its scoring
# function, expected result type, and weight. Components marked as required
# must have a result before the health score is computed; the remaining ones
# contribute only when their check produced a result.
_HEALTH_SCORE_COMPONENTS: tuple[_HealthScoreComponent, ...] = (
    _HealthScoreComponent(
        check_type=checks.BaselineCheck,
        score_function=_get_baseline_score,
        result_type=results.BaselineCheckResult,
        weight=review_constants.HEALTH_SCORE_WEIGHT_BASELINE,
        is_required=True,
    ),
    _HealthScoreComponent(
        check_type=checks.BayesianPPPCheck,
        score_function=_get_bayesian_ppp_score,
        result_type=results.BayesianPPPCheckResult,
        weight=review_constants.HEALTH_SCORE_WEIGHT_BAYESIAN_PPP,
        is_required=True,
    ),
    _HealthScoreComponent(
        check_type=checks.GoodnessOfFitCheck,
        score_function=_get_gof_score,
        result_type=results.GoodnessOfFitCheckResult,
        weight=review_constants.HEALTH_SCORE_WEIGHT_GOF,
        is_required=True,
    ),
    _HealthScoreComponent(
        check_type=checks.PriorPosteriorShiftCheck,
        score_function=_get_pps_score,
        result_type=results.PriorPosteriorShiftCheckResult,
        weight=review_constants.HEALTH_SCORE_WEIGHT_PRIOR_POSTERIOR_SHIFT,
        is_required=False,
    ),
    _HealthScoreComponent(
        check_type=checks.ROIConsistencyCheck,
        score_function=_get_roi_consistency_score,
        result_type=results.ROIConsistencyCheckResult,
        weight=review_constants.HEALTH_SCORE_WEIGHT_ROI_CONSISTENCY,
        is_required=False,
    ),
)
177+
178+
41179class ModelReviewer :
42180 """A tool for executing a series of quality checks on a Meridian model.
43181
@@ -57,15 +195,15 @@ def __init__(
57195 meridian ,
58196 ):
59197 self ._meridian = meridian
60- self ._results : list [ results .CheckResult ] = []
198+ self ._results : MutableMapping [ CheckType , results .CheckResult ] = {}
61199 self ._analyzer = analyzer_module .Analyzer (
62200 model_context = meridian .model_context ,
63201 inference_data = meridian .inference_data ,
64202 )
65203
def _run_and_handle(self, check_class: CheckType, config: configs.BaseConfig):
  """Runs a single check and records its result under the check's class.

  Args:
    check_class: The check class to instantiate with the model, the analyzer
      and `config`.
    config: Configuration passed to the check's constructor.
  """
  check: checks.BaseCheck = check_class(  # pytype: disable=not-instantiable
      self._meridian, self._analyzer, config
  )
  # Keyed by class so a later run of the same check replaces the earlier
  # entry instead of adding a second one.
  self._results[check_class] = check.run()
69207
70208 def _uses_roi_priors (self ):
71209 """Checks if the model uses ROI priors."""
@@ -102,22 +240,60 @@ def _has_custom_roi_priors(self):
102240 return True
103241 return False
104242
def _compute_health_score(self) -> float:
  """Computes the health score of the model.

  The health score is the weighted average of the scores of every health
  score component that has a result in `self._results`.

  Returns:
    The weighted average score, or 0.0 when the total weight of the
    contributing components is zero.

  Raises:
    ValueError: If the result of any required check is missing.
  """
  missing_checks = []
  for component in _HEALTH_SCORE_COMPONENTS:
    if component.is_required and component.check_type not in self._results:
      missing_checks.append(component.check_type.__name__)
  if missing_checks:
    raise ValueError(
        "The following required checks results are missing:"
        f" {missing_checks}."
    )

  scores_and_weights = []
  for component in _HEALTH_SCORE_COMPONENTS:
    # Optional components without a result do not contribute at all.
    if component.check_type not in self._results:
      continue
    check_result = typing.cast(
        component.result_type, self._results[component.check_type]
    )
    scores_and_weights.append(
        (component.score_function(check_result), component.weight)
    )

  sum_score = sum(score * weight for score, weight in scores_and_weights)
  total_weight = sum(weight for _, weight in scores_and_weights)

  if not total_weight:
    return 0.0
  return sum_score / total_weight
278+
105279 def run (self ) -> results .ReviewSummary :
106280 """Executes all checks and generates the final summary."""
107- self ._results . clear ()
281+ self ._results = {}
108282 self ._run_and_handle (checks .ConvergenceCheck , configs .ConvergenceConfig ())
109283
110284 # Stop if the model did not converge.
111285 if (
112286 self ._results
113- and self ._results [0 ].case is results .ConvergenceCases .NOT_CONVERGED
287+ and self ._results [checks .ConvergenceCheck ].case
288+ is results .ConvergenceCases .NOT_CONVERGED
114289 ):
115290 return results .ReviewSummary (
116291 overall_status = results .Status .FAIL ,
117292 summary_message = (
118293 "Failed: Model did not converge. Other checks were skipped."
119294 ),
120- results = self ._results ,
295+ results = list (self ._results .values ()),
296+ health_score = 0.0 ,
121297 )
122298
123299 # Run all other checks in sequence.
@@ -138,10 +314,11 @@ def run(self) -> results.ReviewSummary:
138314
139315 # Determine the final overall status.
140316 has_failures = any (
141- res .case .status is results .Status .FAIL for res in self ._results
317+ res .case .status is results .Status .FAIL for res in self ._results . values ()
142318 )
143319 has_reviews = any (
144- res .case .status is results .Status .REVIEW for res in self ._results
320+ res .case .status is results .Status .REVIEW
321+ for res in self ._results .values ()
145322 )
146323
147324 if has_failures and has_reviews :
@@ -167,5 +344,6 @@ def run(self) -> results.ReviewSummary:
167344 return results .ReviewSummary (
168345 overall_status = overall_status ,
169346 summary_message = summary_message ,
170- results = self ._results ,
347+ results = list (self ._results .values ()),
348+ health_score = self ._compute_health_score (),
171349 )
0 commit comments