|
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import numpy as np

from ...utils.dict_utils import compute_test_quantities, dicts_to_arrays
| 7 | + |
def posterior_z_score(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    test_quantities: dict[str, Callable] = None,
    aggregation: Callable | None = np.median,
) -> dict[str, Any]:
    """
    Computes the posterior z-score from prior to posterior for the given samples according to [1]:

    post_z_score = (posterior_mean - true_parameters) / posterior_std

    The score is adequate if it centers around zero and spreads roughly
    in the interval [-3, 3]

    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021).
    Toward a principled Bayesian workflow in cognitive science.
    Psychological methods, 26(1), 103.

    Paper also available at https://arxiv.org/abs/1904.12765

    Parameters
    ----------
    estimates : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
        Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
        for each data set from `num_datasets`.
    targets : np.ndarray of shape (num_datasets, num_variables)
        Prior samples, comprising `num_datasets` ground truths.
    variable_keys : Sequence[str], optional (default = None)
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to show in the output.
    test_quantities : dict or None, optional, default: None
        A dict that maps plot titles to functions that compute
        test quantities based on estimate/target draws.

        The dict keys are automatically added to ``variable_keys``
        and ``variable_names``.
        Test quantity functions are expected to accept a dict of draws with
        shape ``(batch_size, ...)`` as the first (typically only)
        positional argument and return an NumPy array of shape
        ``(batch_size,)``.
        The functions do not have to deal with an additional
        sample dimension, as appropriate reshaping is done internally.
    aggregation : callable or None, optional (default = np.median)
        Function to aggregate the z-score across datasets. Typically `np.mean` or `np.median`.
        If None is provided, the individual values are returned.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The (optionally aggregated) posterior z-score per variable
        - "metric_name" : str
            The name of the metric ("Posterior z-score").
        - "variable_names" : str
            The (inferred) variable names.

    Notes
    -----
    Posterior z-score quantifies how far the posterior mean lies from
    the true generating parameter, in standard-error units. Values near 0
    (in absolute terms) mean the posterior mean is close to the truth;
    large absolute values indicate substantial deviation.
    The sign shows the direction of the bias.

    """

    # Optionally, compute and prepend test quantities from draws
    if test_quantities is not None:
        updated_data = compute_test_quantities(
            targets=targets,
            estimates=estimates,
            variable_keys=variable_keys,
            variable_names=variable_names,
            test_quantities=test_quantities,
        )
        variable_names = updated_data["variable_names"]
        variable_keys = updated_data["variable_keys"]
        estimates = updated_data["estimates"]
        targets = updated_data["targets"]

    # Normalize dict- or array-valued inputs into aligned arrays.
    samples = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    # Per-dataset posterior mean and (sample) standard deviation over the
    # draws axis; ddof=1 matches the unbiased variance used previously.
    post_means = samples["estimates"].mean(axis=1)
    post_stds = samples["estimates"].std(axis=1, ddof=1)

    # Signed deviation of the posterior mean from the ground truth,
    # expressed in posterior standard deviations.
    z_score = (post_means - samples["targets"]) / post_stds

    # Aggregate across datasets (axis 0), leaving one value per variable.
    if aggregation is not None:
        z_score = aggregation(z_score, axis=0)
    variable_names = samples["estimates"].variable_names
    return {"values": z_score, "metric_name": "Posterior z-score", "variable_names": variable_names}
0 commit comments