@@ -349,8 +349,12 @@ def plot_default_diagnostics(
349349 - Loss history (if training history is available).
350350 - Parameter recovery plots.
351351 - Calibration ECDF plots.
352+ - Coverage plots.
352353 - Z-score contraction plots.
353354
355+ Caution: For models with many parameters, plotting all marginal diagnostics becomes unwieldy. Consider
356+ providing `variables_keyes` for visualizing the diagnostics for subsets of the parameter space.
357+
354358 Parameters
355359 ----------
356360 test_data : Mapping[str, np.ndarray] or int
@@ -400,6 +404,7 @@ def plot_default_diagnostics(
400404 plot_fns = {
401405 "recovery" : bf_plots .recovery ,
402406 "calibration_ecdf" : bf_plots .calibration_ecdf ,
407+ "coverage" : bf_plots .coverage ,
403408 "z_score_contraction" : bf_plots .z_score_contraction ,
404409 }
405410
@@ -499,9 +504,10 @@ def compute_default_diagnostics(
499504 """
500505 Computes default diagnostic metrics to evaluate the quality of inference. The function computes several
501506 diagnostic metrics, including:
502- - Root Mean Squared Error (RMSE)
503- - Posterior contraction
504- - Calibration error
507+ - (Normalized) Root Mean Squared Error ((N)RMSE): summarizes the recovery plots
508+ - Log-gamma statistic - summarizes the ECDF calibration plots
509+ - Expected Calibration Error (ECE) - summarizes the coverage plots
510+ - Posterior contraction - partially summarizes the contraction plots
505511
506512 Parameters
507513 ----------
@@ -553,12 +559,12 @@ def compute_default_diagnostics(
553559 ** kwargs .get ("root_mean_squared_error_kwargs" , {}),
554560 )
555561
556- contraction = bf_metrics .posterior_contraction (
562+ log_gamma = bf_metrics .calibration_log_gamma (
557563 estimates = samples ,
558564 targets = test_data ,
559565 variable_keys = variable_keys ,
560566 variable_names = variable_names ,
561- ** kwargs .get ("posterior_contraction_kwargs " , {}),
567+ ** kwargs .get ("log_gamma_kwargs " , {}),
562568 )
563569
564570 calibration_errors = bf_metrics .calibration_error (
@@ -569,17 +575,26 @@ def compute_default_diagnostics(
569575 ** kwargs .get ("calibration_error_kwargs" , {}),
570576 )
571577
578+ contraction = bf_metrics .posterior_contraction (
579+ estimates = samples ,
580+ targets = test_data ,
581+ variable_keys = variable_keys ,
582+ variable_names = variable_names ,
583+ ** kwargs .get ("posterior_contraction_kwargs" , {}),
584+ )
585+
572586 if as_data_frame :
573587 metrics = pd .DataFrame (
574588 {
575589 root_mean_squared_error ["metric_name" ]: root_mean_squared_error ["values" ],
576- contraction ["metric_name" ]: contraction ["values" ],
590+ log_gamma ["metric_name" ]: log_gamma ["values" ],
577591 calibration_errors ["metric_name" ]: calibration_errors ["values" ],
592+ contraction ["metric_name" ]: contraction ["values" ],
578593 },
579594 index = variable_keys or root_mean_squared_error ["variable_names" ],
580595 ).T
581596 else :
582- metrics = (root_mean_squared_error , contraction , calibration_errors )
597+ metrics = (root_mean_squared_error , log_gamma , calibration_errors , contraction )
583598
584599 return metrics
585600
0 commit comments