diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py index b212128f1..caef625cd 100644 --- a/bayesflow/diagnostics/__init__.py +++ b/bayesflow/diagnostics/__init__.py @@ -19,7 +19,9 @@ mc_confusion_matrix, mmd_hypothesis_test, pairs_posterior, + pairs_quantity, pairs_samples, + plot_quantity, recovery, recovery_from_estimates, z_score_contraction, diff --git a/bayesflow/diagnostics/metrics/posterior_contraction.py b/bayesflow/diagnostics/metrics/posterior_contraction.py index bc91da629..a8dffb922 100644 --- a/bayesflow/diagnostics/metrics/posterior_contraction.py +++ b/bayesflow/diagnostics/metrics/posterior_contraction.py @@ -10,7 +10,7 @@ def posterior_contraction( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, - aggregation: Callable = np.median, + aggregation: Callable | None = np.median, ) -> dict[str, any]: """ Computes the posterior contraction (PC) from prior to posterior for the given samples. @@ -27,8 +27,9 @@ def posterior_contraction( By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to show in the output. - aggregation : callable, optional (default = np.median) + aggregation : callable or None, optional (default = np.median) Function to aggregate the PC across draws. Typically `np.mean` or `np.median`. + If None is provided, the individual values are returned. Returns ------- @@ -36,7 +37,7 @@ def posterior_contraction( Dictionary containing: - "values" : float or np.ndarray - The aggregated posterior contraction per variable + The (optionally aggregated) posterior contraction per variable - "metric_name" : str The name of the metric ("Posterior Contraction"). - "variable_names" : str @@ -59,6 +60,7 @@ def posterior_contraction( post_vars = samples["estimates"].var(axis=1, ddof=1) prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1) contraction = np.clip(1 - (post_vars / prior_vars), 0, 1) - contraction = aggregation(contraction, axis=0) + if aggregation is not None: + contraction = aggregation(contraction, axis=0) variable_names = samples["estimates"].variable_names return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names} diff --git a/bayesflow/diagnostics/plots/__init__.py b/bayesflow/diagnostics/plots/__init__.py index 0904af51e..fe260aa7e 100644 --- a/bayesflow/diagnostics/plots/__init__.py +++ b/bayesflow/diagnostics/plots/__init__.py @@ -6,6 +6,8 @@ from .mc_confusion_matrix import mc_confusion_matrix from .mmd_hypothesis_test import mmd_hypothesis_test from .pairs_posterior import pairs_posterior +from .pairs_quantity import pairs_quantity +from .plot_quantity import plot_quantity from .pairs_samples import pairs_samples from .recovery import recovery from .recovery_from_estimates import recovery_from_estimates diff --git a/bayesflow/diagnostics/plots/calibration_ecdf.py b/bayesflow/diagnostics/plots/calibration_ecdf.py index 5ea62b398..8592e89f1 100644 --- a/bayesflow/diagnostics/plots/calibration_ecdf.py +++ b/bayesflow/diagnostics/plots/calibration_ecdf.py @@ -1,9 +1,9 @@ from collections.abc import Callable, Mapping, Sequence import numpy as np -import keras import matplotlib.pyplot as plt +from ...utils.dict_utils import compute_test_quantities from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots from ...utils.ecdf import simultaneous_ecdf_bands from ...utils.ecdf.ranks import fractional_ranks, distance_ranks @@ -136,38 +136,17 @@ def calibration_ecdf( # Optionally, compute and prepend test quantities from draws if test_quantities is not None: - test_quantities_estimates = {} - test_quantities_targets = {} - - for key, test_quantity_fn in test_quantities.items(): - # Apply test_quantity_func to ground-truths - tq_targets = test_quantity_fn(data=targets) - test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1) - - # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape - num_conditions, num_samples = next(iter(estimates.values())).shape[:2] - flattened_estimates = keras.tree.map_structure( - lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:])) - if isinstance(t, np.ndarray) - else t, - estimates, - ) - flat_tq_estimates = test_quantity_fn(data=flattened_estimates) - test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1)) - - # Add custom test quantities to variable keys and names for plotting - # keys and names are set to the test_quantities dict keys - test_quantities_names = list(test_quantities.keys()) - - if variable_keys is None: - variable_keys = list(estimates.keys()) - - if isinstance(variable_names, list): - variable_names = test_quantities_names + variable_names - - variable_keys = test_quantities_names + variable_keys - estimates = test_quantities_estimates | estimates - targets = test_quantities_targets | targets + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] plot_data = prepare_plot_data( estimates=estimates, diff --git a/bayesflow/diagnostics/plots/pairs_quantity.py b/bayesflow/diagnostics/plots/pairs_quantity.py new file mode 100644 index 000000000..00fc24199 --- /dev/null +++ b/bayesflow/diagnostics/plots/pairs_quantity.py @@ -0,0 +1,262 @@ +from collections.abc import Callable, Sequence, Mapping + +import matplotlib +import matplotlib.pyplot as plt + +import numpy as np +import pandas as pd +import seaborn as sns + + +from .plot_quantity import _prepare_values + + +def pairs_quantity( + values: Mapping[str, np.ndarray] | np.ndarray | Callable, + targets: Mapping[str, np.ndarray] | np.ndarray, + *, + variable_keys: Sequence[str] = None, + variable_names: Sequence[str] = None, + estimates: Mapping[str, np.ndarray] | np.ndarray | None = None, + test_quantities: dict[str, Callable] = None, + height: float = 2.5, + cmap: str | matplotlib.colors.Colormap = "viridis", + alpha: float = 0.9, + markersize: float = 8.0, + marker: str = "o", + label: str = None, + label_fontsize: int = 14, + tick_fontsize: int = 12, + colorbar_label_fontsize: int = 14, + colorbar_tick_fontsize: int = 12, + colorbar_width: float = 1.8, + colorbar_height: float = 0.06, + colorbar_offset: float = 0.06, + vmin: float = None, + vmax: float = None, + default_name: str = "v", + **kwargs, +) -> sns.PairGrid: + """ + A pair plot function to plot quantities against their generating + parameter values. + + The value is indicated by a colormap. The marginal distribution for + each parameter is plotted on the diagonal. Each column displays the + values of corresponding to the parameter in the column. + + The function supports the following different combinations to pass + or compute the values: + + 1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables) + 2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names' + as provided by the metrics functions. Note that the functions have to be called + without aggregation to obtain value per dataset. + 3. pass a function to `values`, as well as `estimates`. The function should have the + signature fn(estimates, targets, [aggregation]) and return an object like the + `values` described in the previous options. + + Parameters + ---------- + values : dict[str, np.ndarray] | np.ndarray | Callable, + The value of the quantity to plot. One of the following: + + 1. an array of shape (num_datasets,) or (num_datasets, num_variables) + 2. a dictionary with the keys 'values', 'metric_name' and 'variable_names' + as provided by the metrics functions. Note that the functions have to be called + without aggregation to obtain value per dataset. + 3. a callable, requires passing `estimates` as well. The function should have the + signature fn(estimates, targets, [aggregation]) and return an object like the + ones described in the previous options. + targets : dict[str, np.ndarray] | np.ndarray, + The parameter values plotted on the axis. + variable_keys : list or None, optional, default: None + Select keys from the dictionary provided in samples. + By default, select all keys. + variable_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None + The posterior draws obtained from n_data_sets. Can only be supplied if + `values` is of type Callable. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + Can only be supplied if `values` is a function. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return an NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. + height : float, optional, default: 2.5 + The height of the pair plot + cmap : str or Colormap, default: "viridis" + The colormap for the plot. + alpha : float in [0, 1], optional, default: 0.9 + The opacity of the plot + markersize : float, optional, default: 8.0 + The marker size in points**2 for the scatter plot. + marker : str, optional, default: 'o' + The marker for the scatter plot. + label : str, optional, default: None + Label for the dataset to plot. + label_fontsize : int, optional, default: 14 + The font size of the x and y-label texts (parameter names) + tick_fontsize : int, optional, default: 12 + The font size of the axis tick labels + colorbar_label_fontsize : int, optional, default: 14 + The font size of the colorbar label + colorbar_tick_fontsize : int, optional, default: 12 + The font size of the colorbar tick labels + colorbar_width : float, optional, default: 1.8 + The width of the colorbar in inches + colorbar_height : float, optional, default: 0.06 + The height of the colorbar in inches + colorbar_offset : float, optional, default: 0.06 + The vertical offset of the colorbar in inches + vmin : float, optional, default: None + Minimum value for the colormap. If None, the minimum value is + determined from `values`. + vmax : float, optional, default: None + Maximum value for the colormap. If None, the maximum value is + determined from `values`. + default_name : str, optional (default = "v") + The default name to use for estimates if None provided + **kwargs : dict, optional + Additional keyword arguments passed to the sns.PairGrid constructor + + Returns + ------- + plt.Figure + The figure instance + + Raises + ------ + ValueError + If a callable is supplied as `values`, but `estimates` is None. + """ + + if isinstance(values, Callable) and estimates is None: + raise ValueError("Supplied a callable as `values`, but no `estimates`.") + if not isinstance(values, Callable) and test_quantities is not None: + raise ValueError( + "Supplied `test_quantities`, but `values` is not a function. " + "As the values have to be calculated for the test quantities, " + "passing a function is required." + ) + + d = _prepare_values( + values=values, + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + label=label, + default_name=default_name, + ) + (values, targets, variable_keys, variable_names, test_quantities, label) = ( + d["values"], + d["targets"], + d["variable_keys"], + d["variable_names"], + d["test_quantities"], + d["label"], + ) + + # Convert samples to pd.DataFrame + data_to_plot = pd.DataFrame(targets, columns=variable_names) + + # initialize plot + g = sns.PairGrid( + data_to_plot, + height=height, + vars=variable_names, + **kwargs, + ) + + vmin = values.min() if vmin is None else vmin + vmax = values.max() if vmax is None else vmax + + # Generate grids + dim = g.axes.shape[0] + for i in range(dim): + for j in range(dim): + # if one value for each variable is supplied, use it for the corresponding column + row_values = values[:, j] if values.ndim == 2 else values + + if i == j: + ax = g.axes[i, j].twinx() + ax.scatter( + targets[:, i], + values[:, i], + c=row_values, + cmap=cmap, + s=markersize, + marker=marker, + vmin=vmin, + vmax=vmax, + alpha=alpha, + ) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.tick_params(axis="both", which="major", labelsize=tick_fontsize) + ax.tick_params(axis="both", which="minor", labelsize=tick_fontsize) + ax.set_ylim(vmin, vmax) + + if i > 0: + g.axes[i, j].get_yaxis().set_visible(False) + g.axes[i, j].spines["left"].set_visible(False) + if i == dim - 1: + ax.set_ylabel(label, size=label_fontsize) + else: + g.axes[i, j].grid(alpha=0.5) + g.axes[i, j].set_axisbelow(True) + g.axes[i, j].scatter( + targets[:, j], + targets[:, i], + c=row_values, + cmap=cmap, + s=markersize, + vmin=vmin, + vmax=vmax, + alpha=alpha, + marker=marker, + ) + + def inches_to_figure(fig, values): + return fig.transFigure.inverted().transform(fig.dpi_scale_trans.transform(values)) + + # position and draw colorbar + _, yoffset = inches_to_figure(g.figure, [0, colorbar_offset]) + cwidth, cheight = inches_to_figure(g.figure, [colorbar_width, colorbar_offset]) + cax = g.figure.add_axes([0.5 - cwidth / 2, -yoffset - cheight, cwidth, cheight]) + + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + cbar = plt.colorbar( + matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap), + cax=cax, + location="bottom", + label=label, + alpha=alpha, + ) + + cbar.set_label(label, size=colorbar_label_fontsize) + cax.tick_params(labelsize=colorbar_tick_fontsize) + + dim = g.axes.shape[0] + for i in range(dim): + # Modify tick sizes + for j in range(i + 1): + g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize) + g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize) + + # adjust the font size of labels + # the labels themselves remain the same as before, i.e., variable_names + g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize) + g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize) + + return g diff --git a/bayesflow/diagnostics/plots/plot_quantity.py b/bayesflow/diagnostics/plots/plot_quantity.py new file mode 100644 index 000000000..2fd0c6841 --- /dev/null +++ b/bayesflow/diagnostics/plots/plot_quantity.py @@ -0,0 +1,282 @@ +from collections.abc import Callable, Sequence, Mapping + +import matplotlib.pyplot as plt +import numpy as np + +from bayesflow.utils.dict_utils import make_variable_array, dicts_to_arrays, filter_kwargs, compute_test_quantities +from bayesflow.utils.plot_utils import ( + add_titles_and_labels, + make_figure, + set_layout, + prettify_subplots, +) +from bayesflow.utils.validators import check_estimates_prior_shapes + + +def plot_quantity( + values: Mapping[str, np.ndarray] | np.ndarray | Callable, + targets: Mapping[str, np.ndarray] | np.ndarray, + *, + variable_keys: Sequence[str] = None, + variable_names: Sequence[str] = None, + estimates: Mapping[str, np.ndarray] | np.ndarray | None = None, + test_quantities: dict[str, Callable] = None, + figsize: Sequence[int] = None, + label_fontsize: int = 16, + title_fontsize: int = 18, + tick_fontsize: int = 12, + color: str = "#132a70", + markersize: float = 25.0, + marker: str = "o", + alpha: float = 0.5, + xlabel: str = "Ground truth", + ylabel: str = "", + num_col: int = None, + num_row: int = None, + default_name: str = "v", +) -> plt.Figure: + """ + Plot a quantity as a function of a variable for each variable key. + + The function supports the following different combinations to pass + or compute the values: + + 1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables) + 2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names' + as provided by the metrics functions. Note that the functions have to be called + without aggregation to obtain value per dataset. + 3. pass a function to `values`, as well as `estimates`. The function should have the + signature fn(estimates, targets, [aggregation]) and return an object like the + `values` described in the previous options. + + Parameters + ---------- + values : dict[str, np.ndarray] | np.ndarray | Callable, + The value of the quantity to plot. One of the following: + + 1. an array of shape (num_datasets,) or (num_datasets, num_variables) + 2. a dictionary with the keys 'values', 'metric_name' and 'variable_names' + as provided by the metrics functions. Note that the functions have to be called + without aggregation to obtain value per dataset. + 3. a callable, requires passing `estimates` as well. The function should have the + signature fn(estimates, targets, [aggregation]) and return an object like the + ones described in the previous options. + targets : dict[str, np.ndarray] | np.ndarray, + The parameter values plotted on the axis. + variable_keys : list or None, optional, default: None + Select keys from the dictionary provided in samples. + By default, select all keys. + variable_names : list or None, optional, default: None + The parameter names for nice plot titles. Inferred if None + estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None + The posterior draws obtained from n_data_sets. Can only be supplied if + `values` is of type Callable. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + Can only be supplied if `values` is a function. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return an NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. + figsize : tuple or None, optional, default : None + The figure size passed to the matplotlib constructor. Inferred if None. + label_fontsize : int, optional, default: 16 + The font size of the y-label text + title_fontsize : int, optional, default: 18 + The font size of the title text + tick_fontsize : int, optional, default: 12 + The font size of the axis ticklabels + color : str, optional, default: '#8f2727' + The color for the true vs. estimated scatter points and error bars + markersize : float, optional, default: 25.0 + The marker size in points**2 for the scatter plot. + marker : str, optional, default: 'o' + The marker for the scatter plot. + alpha : float, default: 0.5 + The opacity for the scatter plot + num_row : int, optional, default: None + The number of rows for the subplots. Dynamically determined if None. + num_col : int, optional, default: None + The number of columns for the subplots. Dynamically determined if None. + default_name : str, optional (default = "v") + The default name to use for estimates if None provided + + Returns + ------- + f : plt.Figure - the figure instance for optional saving + + Raises + ------ + ShapeError + If there is a deviation from the expected shapes of ``estimates`` and ``targets``. + """ + + if isinstance(values, Callable) and estimates is None: + raise ValueError("Supplied a callable as `values`, but no `estimates`.") + if not isinstance(values, Callable) and test_quantities is not None: + raise ValueError( + "Supplied `test_quantities`, but `values` is not a function. " + "As the values have to be calculated for the test quantities, " + "passing a function is required." + ) + + d = _prepare_values( + values=values, + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + label=None, + default_name=default_name, + ) + (values, targets, variable_keys, variable_names, test_quantities, _) = ( + d["values"], + d["targets"], + d["variable_keys"], + d["variable_names"], + d["test_quantities"], + d["label"], + ) + + # store variable information at the top level for easy access + num_variables = len(variable_names) + + # Configure layout + num_row, num_col = set_layout(num_variables, num_row, num_col) + + # Initialize figure + fig, axes = make_figure(num_row, num_col, figsize=figsize) + + # Loop and plot + for i, ax in enumerate(axes.flat): + if i >= num_variables: + break + + ax.scatter(targets[:, i], values[:, i], color=color, alpha=alpha, s=markersize, marker=marker) + + prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize) + + # Add labels, titles, and set font sizes + add_titles_and_labels( + axes=axes, + num_row=num_row, + num_col=num_col, + title=variable_names, + xlabel=xlabel, + ylabel=ylabel, + title_fontsize=title_fontsize, + label_fontsize=label_fontsize, + ) + + fig.tight_layout() + return fig + + +def _prepare_values( + *, + values: Mapping[str, np.ndarray] | np.ndarray | Callable, + targets: Mapping[str, np.ndarray] | np.ndarray, + estimates: Mapping[str, np.ndarray] | np.ndarray | None, + variable_keys: Sequence[str], + variable_names: Sequence[str], + test_quantities: dict[str, Callable], + label: str | None, + default_name: str, +): + """ + Private helper function to compute/extract the values required for plotting + a quantity. + + Refer to pairs_quantity and plot_quantity for details. + """ + is_values_callable = isinstance(values, Callable) + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + + if estimates is not None: + if is_values_callable: + values = values(estimates=estimates, targets=targets, **filter_kwargs({"aggregation": None}, values)) + + data = dicts_to_arrays( + estimates=estimates, + targets=targets, + variable_keys=variable_keys, + variable_names=variable_names, + default_name=default_name, + ) + check_estimates_prior_shapes(data["estimates"], data["targets"]) + estimates = data["estimates"] + targets = data["targets"] + + variable_keys = variable_keys or estimates.variable_keys + if test_quantities is None: + variable_names = variable_names or estimates.variable_names + + if all([key in values for key in ["values", "metric_name", "variable_names"]]): + # output of a metric function + label = values["metric_name"] if label is None else label + variable_names = variable_names or values["variable_names"] + values = values["values"] + + if hasattr(values, "variable_keys"): + variable_keys = variable_keys or values.variable_keys + if hasattr(values, "variable_names") and test_quantities is None: + variable_names = variable_names or values.variable_names + + try: + targets = make_variable_array( + targets, + variable_keys=variable_keys, + variable_names=variable_names, + default_name=default_name, + ) + except ValueError: + raise ValueError( + "Length of 'variable_names' and number of variables do not match. " + "Did you forget to specify `variable_keys`?" + ) + variable_names = targets.variable_names + variable_keys = targets.variable_keys + + if values.ndim == 1: + values = values[:, None].repeat(len(variable_names), axis=-1) + + try: + values = make_variable_array( + values, + variable_keys=variable_keys, + variable_names=variable_names, + default_name=default_name, + ) + except ValueError: + raise ValueError( + "Length of 'variable_names' and number of variables do not match. " + "Did you forget to specify `variable_keys`?" + ) + + return { + "values": values, + "targets": targets, + "variable_keys": variable_keys, + "variable_names": variable_names, + "test_quantities": test_quantities, + "label": label, + } diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index 5d085f0ec..c94429559 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -219,7 +219,7 @@ def make_variable_array( # reuse existing variable keys and names if contained in x if variable_names is None: variable_names = x.variable_names - if variable_keys in None: + if variable_keys is None: variable_keys = x.variable_keys # use default names if not otherwise specified @@ -344,3 +344,55 @@ def squeeze_inner_estimates_dict(estimates): return estimates["value"] else: return estimates + + +def compute_test_quantities( + targets: Mapping[str, np.ndarray] | np.ndarray, + estimates: Mapping[str, np.ndarray] | np.ndarray | None = None, + variable_keys: Sequence[str] = None, + variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, +): + """Compute additional test quantities for given targets and estimates.""" + import keras + + test_quantities_estimates = {} if estimates is not None else None + test_quantities_targets = {} + + for key, test_quantity_fn in test_quantities.items(): + # Apply test_quantity_func to ground-truths + tq_targets = test_quantity_fn(data=targets) + test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1) + + if estimates is not None: + # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape + num_conditions, num_samples = next(iter(estimates.values())).shape[:2] + flattened_estimates = keras.tree.map_structure( + lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:])) + if isinstance(t, np.ndarray) + else t, + estimates, + ) + flat_tq_estimates = test_quantity_fn(data=flattened_estimates) + test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1)) + + # Add custom test quantities to variable keys and names for plotting + # keys and names are set to the test_quantities dict keys + test_quantities_names = list(test_quantities.keys()) + + if variable_keys is None: + variable_keys = list(estimates.keys() if estimates is not None else targets.keys()) + if isinstance(variable_names, list): + variable_names = test_quantities_names + variable_names + + variable_keys = test_quantities_names + variable_keys + if estimates is not None: + estimates = test_quantities_estimates | estimates + targets = test_quantities_targets | targets + + return { + "variable_keys": variable_keys, + "estimates": estimates, + "targets": targets, + "variable_names": variable_names, + } diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index 92de891c4..35d9276d3 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -43,6 +43,9 @@ def test_posterior_contraction(random_estimates, random_targets): assert out["values"].shape == (num_variables(random_estimates),) assert out["metric_name"] == "Posterior Contraction" assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] + # test without aggregation + out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets, aggregation=None) + assert out["values"].shape == (random_estimates["sigma"].shape[0], num_variables(random_estimates)) def test_root_mean_squared_error(random_estimates, random_targets): diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index 952fe4002..e2f1c09f2 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -161,6 +161,102 @@ def test_pairs_posterior(random_estimates, random_targets, random_priors): ) +def test_pairs_quantity(random_estimates, random_targets, random_priors): + # test test_quantities and label assignment + key = next(iter(random_estimates.keys())) + test_quantities = { + "a": lambda data: np.sum(data[key], axis=-1), + "b": lambda data: np.prod(data[key], axis=-1), + } + out = bf.diagnostics.plots.pairs_quantity( + values=bf.diagnostics.posterior_contraction, + estimates=random_estimates, + targets=random_targets, + test_quantities=test_quantities, + ) + + num_vars = num_variables(random_estimates) + len(test_quantities) + assert out.axes.shape == (num_vars, num_vars) + assert out.axes[0, 0].get_ylabel() == "a" + assert out.axes[2, 0].get_ylabel() == "beta_0" + assert out.axes[4, 4].get_xlabel() == "sigma" + + values = bf.diagnostics.posterior_contraction(estimates=random_estimates, targets=random_targets, aggregation=None) + + bf.diagnostics.plots.pairs_quantity( + values, + targets=random_targets, + ) + + raw_values = np.random.normal(size=values["values"].shape) + out = bf.diagnostics.plots.pairs_quantity(raw_values, targets=random_targets, variable_keys=["beta", "sigma"]) + assert out.axes.shape == (3, 3) + + with pytest.raises(ValueError): + bf.diagnostics.plots.pairs_quantity(raw_values, targets=random_targets) + + with pytest.raises(ValueError): + bf.diagnostics.plots.pairs_quantity( + values=values, + estimates=random_estimates, + targets=random_targets, + test_quantities=test_quantities, + ) + + with pytest.raises(ValueError): + bf.diagnostics.plots.pairs_quantity( + values=bf.diagnostics.posterior_contraction, + targets=random_targets, + ) + + +def test_plot_quantity(random_estimates, random_targets, random_priors): + # test test_quantities and label assignment + key = next(iter(random_estimates.keys())) + test_quantities = { + "a": lambda data: np.sum(data[key], axis=-1), + "b": lambda data: np.prod(data[key], axis=-1), + } + out = bf.diagnostics.plots.plot_quantity( + values=bf.diagnostics.posterior_contraction, + estimates=random_estimates, + targets=random_targets, + test_quantities=test_quantities, + ) + + num_vars = num_variables(random_estimates) + len(test_quantities) + assert len(out.axes) == num_vars + assert out.axes[0].title._text == "a" + + values = bf.diagnostics.posterior_contraction(estimates=random_estimates, targets=random_targets, aggregation=None) + + bf.diagnostics.plots.plot_quantity( + values, + targets=random_targets, + ) + + raw_values = np.random.normal(size=values["values"].shape) + out = bf.diagnostics.plots.plot_quantity(raw_values, targets=random_targets, variable_keys=["beta", "sigma"]) + assert len(out.axes) == 3 + + with pytest.raises(ValueError): + bf.diagnostics.plots.plot_quantity(raw_values, targets=random_targets) + + with pytest.raises(ValueError): + bf.diagnostics.plots.plot_quantity( + values=values, + estimates=random_estimates, + targets=random_targets, + test_quantities=test_quantities, + ) + + with pytest.raises(ValueError): + bf.diagnostics.plots.plot_quantity( + values=bf.diagnostics.posterior_contraction, + targets=random_targets, + ) + + def test_mc_calibration(pred_models, true_models, model_names): out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names, markersize=4) assert len(out.axes) == pred_models.shape[-1]