From 3b761b250a7480a80f8874c6b9fe23a40bc44934 Mon Sep 17 00:00:00 2001 From: Svenja Jedhoff Date: Wed, 15 Oct 2025 14:33:43 +0200 Subject: [PATCH 1/6] Adding custom test quantities to diagnostics --- .../diagnostics/metrics/calibration_error.py | 28 +++++++++++++++- .../metrics/calibration_log_gamma.py | 32 +++++++++++++++++-- .../metrics/posterior_contraction.py | 30 ++++++++++++++++- .../metrics/root_mean_squared_error.py | 29 ++++++++++++++++- .../plots/calibration_histogram.py | 30 ++++++++++++++++- bayesflow/diagnostics/plots/coverage.py | 30 ++++++++++++++++- bayesflow/diagnostics/plots/recovery.py | 30 ++++++++++++++++- .../diagnostics/plots/z_score_contraction.py | 32 +++++++++++++++++-- 8 files changed, 231 insertions(+), 10 deletions(-) diff --git a/bayesflow/diagnostics/metrics/calibration_error.py b/bayesflow/diagnostics/metrics/calibration_error.py index 1d298370c..0b54c257b 100644 --- a/bayesflow/diagnostics/metrics/calibration_error.py +++ b/bayesflow/diagnostics/metrics/calibration_error.py @@ -2,7 +2,7 @@ import numpy as np -from ...utils.dict_utils import dicts_to_arrays +from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities def calibration_error( @@ -10,6 +10,7 @@ def calibration_error( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, resolution: int = 20, aggregation: Callable = np.median, min_quantile: float = 0.005, @@ -32,6 +33,18 @@ def calibration_error( By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to show in the output. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. 
+ Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. resolution : int, optional, default: 20 The number of credibility intervals (CIs) to consider aggregation : callable or None, optional, default: np.median @@ -55,6 +68,19 @@ def calibration_error( The (inferred) variable names. """ + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + samples = dicts_to_arrays( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/metrics/calibration_log_gamma.py b/bayesflow/diagnostics/metrics/calibration_log_gamma.py index 54551c857..a25cf900f 100644 --- a/bayesflow/diagnostics/metrics/calibration_log_gamma.py +++ b/bayesflow/diagnostics/metrics/calibration_log_gamma.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence import numpy as np from scipy.stats import binom -from ...utils.dict_utils import dicts_to_arrays +from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities def calibration_log_gamma( @@ -11,6 +11,7 @@ def calibration_log_gamma( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, num_null_draws: int = 1000, quantile: float = 0.05, ): @@ -41,6 +42,18 @@ def calibration_log_gamma( By default, select all keys. 
variable_names : Sequence[str], optional (default = None) Optional variable names to show in the output. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. quantile : float in (0, 1), optional, default 0.05 The quantile from the null distribution to be used as a threshold. A lower quantile increases sensitivity to deviations from uniformity. @@ -57,6 +70,21 @@ def calibration_log_gamma( - "variable_names" : str The (inferred) variable names. """ + + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + samples = dicts_to_arrays( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/metrics/posterior_contraction.py b/bayesflow/diagnostics/metrics/posterior_contraction.py index a8dffb922..1004bda40 100644 --- a/bayesflow/diagnostics/metrics/posterior_contraction.py +++ b/bayesflow/diagnostics/metrics/posterior_contraction.py @@ -2,7 +2,7 @@ import numpy as np -from ...utils.dict_utils import dicts_to_arrays +from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities def posterior_contraction( @@ -10,6 +10,7 @@ def 
posterior_contraction( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, aggregation: Callable | None = np.median, ) -> dict[str, any]: """ @@ -27,6 +28,18 @@ def posterior_contraction( By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to show in the output. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. aggregation : callable or None, optional (default = np.median) Function to aggregate the PC across draws. Typically `np.mean` or `np.median`. If None is provided, the individual values are returned. @@ -50,6 +63,21 @@ def posterior_contraction( indicate low contraction. 
""" + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + + samples = dicts_to_arrays( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/metrics/root_mean_squared_error.py b/bayesflow/diagnostics/metrics/root_mean_squared_error.py index 7c3c6305a..d8209ffac 100644 --- a/bayesflow/diagnostics/metrics/root_mean_squared_error.py +++ b/bayesflow/diagnostics/metrics/root_mean_squared_error.py @@ -2,7 +2,7 @@ import numpy as np -from ...utils.dict_utils import dicts_to_arrays +from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities def root_mean_squared_error( @@ -10,6 +10,7 @@ def root_mean_squared_error( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, normalize: str | None = "range", aggregation: Callable = np.median, ) -> dict[str, any]: @@ -28,6 +29,18 @@ def root_mean_squared_error( By default, select all keys. variable_names : Sequence[str], optional (default = None) Optional variable names to show in the output. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return an NumPy array of shape + ``(batch_size,)``. 
+ The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. normalize : str or None, optional (default = "range") Whether to normalize the RMSE using statistics of the prior samples. Possible options are ("mean", "range", "median", "iqr", "std", None) @@ -52,6 +65,20 @@ def root_mean_squared_error( The (inferred) variable names. """ + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + samples = dicts_to_arrays( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/plots/calibration_histogram.py b/bayesflow/diagnostics/plots/calibration_histogram.py index 7bd4ce90b..1420f7a5f 100644 --- a/bayesflow/diagnostics/plots/calibration_histogram.py +++ b/bayesflow/diagnostics/plots/calibration_histogram.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence, Mapping +from collections.abc import Callable, Mapping, Sequence import matplotlib.pyplot as plt import numpy as np @@ -8,6 +8,7 @@ from bayesflow.utils import logging from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots +from bayesflow.utils.dict_utils import compute_test_quantities def calibration_histogram( @@ -15,6 +16,7 @@ def calibration_histogram( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, figsize: Sequence[float] = None, num_bins: int = 10, binomial_interval: float = 0.99, @@ -46,6 +48,18 @@ def calibration_histogram( By default, select all keys. 
variable_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. figsize : tuple or None, optional, default : None The figure size passed to the matplotlib constructor. Inferred if None num_bins : int, optional, default: 10 @@ -75,6 +89,20 @@ def calibration_histogram( If there is a deviation form the expected shapes of `estimates` and `targets`. """ + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + plot_data = prepare_plot_data( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/plots/coverage.py b/bayesflow/diagnostics/plots/coverage.py index c632b394f..65f370cbe 100644 --- a/bayesflow/diagnostics/plots/coverage.py +++ b/bayesflow/diagnostics/plots/coverage.py @@ -1,9 +1,10 @@ -from collections.abc import Sequence, Mapping +from collections.abc import Callable, Sequence, Mapping import matplotlib.pyplot as plt import numpy as np from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, 
compute_empirical_coverage +from bayesflow.utils.dict_utils import compute_test_quantities def coverage( @@ -12,6 +13,7 @@ def coverage( difference: bool = False, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, figsize: Sequence[int] = None, label_fontsize: int = 16, legend_fontsize: int = 14, @@ -50,6 +52,18 @@ def coverage( By default, select all keys. variable_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. figsize : tuple or None, optional, default: None The figure size passed to the matplotlib constructor. Inferred if None. 
label_fontsize : int, optional, default: 16 @@ -80,6 +94,20 @@ def coverage( """ + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=estimates, diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index f591a4284..864681fac 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -5,13 +5,14 @@ from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric from bayesflow.utils.numpy_utils import credible_interval - +from bayesflow.utils.dict_utils import compute_test_quantities def recovery( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, point_agg: Callable = np.median, uncertainty_agg: Callable = credible_interval, point_agg_kwargs: dict = None, @@ -58,6 +59,18 @@ def recovery( By default, select all keys. variable_names : list or None, optional, default: None The individual parameter names for nice plot titles. Inferred if None + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. 
+ Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. point_agg : callable, optional, default: median Function to compute point estimates. uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95% @@ -104,6 +117,21 @@ def recovery( If there is a deviation from the expected shapes of ``estimates`` and ``targets``. """ + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + + # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=estimates, diff --git a/bayesflow/diagnostics/plots/z_score_contraction.py b/bayesflow/diagnostics/plots/z_score_contraction.py index c013889a0..f80116f62 100644 --- a/bayesflow/diagnostics/plots/z_score_contraction.py +++ b/bayesflow/diagnostics/plots/z_score_contraction.py @@ -1,16 +1,17 @@ -from collections.abc import Sequence, Mapping +from collections.abc import Callable, Sequence, Mapping import matplotlib.pyplot as plt import numpy as np from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots - +from bayesflow.utils.dict_utils import compute_test_quantities def z_score_contraction( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + 
test_quantities: dict[str, Callable] = None, figsize: Sequence[int] = None, label_fontsize: int = 16, title_fontsize: int = 18, @@ -63,6 +64,18 @@ def z_score_contraction( By default, select all keys. variable_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return a NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. figsize : tuple or None, optional, default : None The figure size passed to the matplotlib constructor. Inferred if None. label_fontsize : int, optional, default: 16 @@ -90,6 +103,21 @@ def z_score_contraction( If there is a deviation from the expected shapes of ``estimates`` and ``targets``. 
""" + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + updated_data = compute_test_quantities( + targets=targets, + estimates=estimates, + variable_keys=variable_keys, + variable_names=variable_names, + test_quantities=test_quantities, + ) + variable_names = updated_data["variable_names"] + variable_keys = updated_data["variable_keys"] + estimates = updated_data["estimates"] + targets = updated_data["targets"] + + # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=estimates, From 2207e944bb72a334cd24fc0284e000690b9da44f Mon Sep 17 00:00:00 2001 From: Svenja Jedhoff Date: Wed, 15 Oct 2025 14:46:55 +0200 Subject: [PATCH 2/6] Fix Code Style --- bayesflow/diagnostics/metrics/posterior_contraction.py | 1 - bayesflow/diagnostics/plots/recovery.py | 2 +- bayesflow/diagnostics/plots/z_score_contraction.py | 2 +- tests/test_links/test_links.py | 6 +++--- .../test_point_inference_network.py | 6 +++--- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/bayesflow/diagnostics/metrics/posterior_contraction.py b/bayesflow/diagnostics/metrics/posterior_contraction.py index 1004bda40..94749c7bc 100644 --- a/bayesflow/diagnostics/metrics/posterior_contraction.py +++ b/bayesflow/diagnostics/metrics/posterior_contraction.py @@ -77,7 +77,6 @@ def posterior_contraction( estimates = updated_data["estimates"] targets = updated_data["targets"] - samples = dicts_to_arrays( estimates=estimates, targets=targets, diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 864681fac..a98d3b12c 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -7,6 +7,7 @@ from bayesflow.utils.numpy_utils import credible_interval from bayesflow.utils.dict_utils import compute_test_quantities + def recovery( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, @@ -131,7 +132,6 @@ 
def recovery( estimates = updated_data["estimates"] targets = updated_data["targets"] - # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=estimates, diff --git a/bayesflow/diagnostics/plots/z_score_contraction.py b/bayesflow/diagnostics/plots/z_score_contraction.py index f80116f62..190a0a608 100644 --- a/bayesflow/diagnostics/plots/z_score_contraction.py +++ b/bayesflow/diagnostics/plots/z_score_contraction.py @@ -6,6 +6,7 @@ from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots from bayesflow.utils.dict_utils import compute_test_quantities + def z_score_contraction( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, @@ -117,7 +118,6 @@ def z_score_contraction( estimates = updated_data["estimates"] targets = updated_data["targets"] - # Gather plot data and metadata into a dictionary plot_data = prepare_plot_data( estimates=estimates, diff --git a/tests/test_links/test_links.py b/tests/test_links/test_links.py index 9285f689b..396bc0bfd 100644 --- a/tests/test_links/test_links.py +++ b/tests/test_links/test_links.py @@ -23,9 +23,9 @@ def check_ordering(output, axis): assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." for i in range(output.ndim): if i != axis % output.ndim: - assert not np.all(np.diff(output, axis=i) > 0), ( - f"is ordered along axis which is not meant to be ordered: {i}." - ) + assert not np.all( + np.diff(output, axis=i) > 0 + ), f"is ordered along axis which is not meant to be ordered: {i}." 
@pytest.mark.parametrize("axis", [0, 1, 2]) diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 38ba8ea4e..8992e923a 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random for key_outer in out1.keys(): for key_inner in out1[key_outer].keys(): - assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), ( - "Output of original and loaded model differs significantly." - ) + assert keras.ops.all( + keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner]) + ), "Output of original and loaded model differs significantly." def test_copy_unequal(point_inference_network, random_samples, random_conditions): From d21d058317b4aff2d5c1de4e5e34463bd71ccb8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Wed, 15 Oct 2025 14:56:07 +0200 Subject: [PATCH 3/6] run ruff again --- tests/test_links/test_links.py | 6 +++--- .../test_point_inference_network.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_links/test_links.py b/tests/test_links/test_links.py index 396bc0bfd..9285f689b 100644 --- a/tests/test_links/test_links.py +++ b/tests/test_links/test_links.py @@ -23,9 +23,9 @@ def check_ordering(output, axis): assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." for i in range(output.ndim): if i != axis % output.ndim: - assert not np.all( - np.diff(output, axis=i) > 0 - ), f"is ordered along axis which is not meant to be ordered: {i}." + assert not np.all(np.diff(output, axis=i) > 0), ( + f"is ordered along axis which is not meant to be ordered: {i}." 
+ ) @pytest.mark.parametrize("axis", [0, 1, 2]) diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 8992e923a..38ba8ea4e 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random for key_outer in out1.keys(): for key_inner in out1[key_outer].keys(): - assert keras.ops.all( - keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner]) - ), "Output of original and loaded model differs significantly." + assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), ( + "Output of original and loaded model differs significantly." + ) def test_copy_unequal(point_inference_network, random_samples, random_conditions): From bb3e4d150a06215c3d689874c7151f0231673e85 Mon Sep 17 00:00:00 2001 From: Svenja Jedhoff Date: Wed, 15 Oct 2025 16:33:01 +0200 Subject: [PATCH 4/6] Adding tests --- .../test_diagnostics_metrics.py | 37 +++++++++++++++++ .../test_diagnostics_plots.py | 40 +++++++++++++++++++ tests/test_links/test_links.py | 6 +-- .../test_point_inference_network.py | 6 +-- 4 files changed, 83 insertions(+), 6 deletions(-) diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index 5945412c7..c2b5dba5c 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -35,6 +35,13 @@ def test_metric_calibration_error(random_estimates, random_targets, var_names): assert out["values"].shape == (random_estimates["sigma"].shape[-1],) assert out["variable_names"] == ["sigma"] + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: 
np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets, test_quantities=test_quantities) + assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates) def test_posterior_contraction(random_estimates, random_targets): # basic functionality: automatic variable names @@ -47,6 +54,16 @@ def test_posterior_contraction(random_estimates, random_targets): out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets, aggregation=None) assert out["values"].shape == (random_estimates["sigma"].shape[0], num_variables(random_estimates)) + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.metrics.posterior_contraction( + random_estimates, random_targets, test_quantities=test_quantities + ) + assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates) + def test_root_mean_squared_error(random_estimates, random_targets): # basic functionality: automatic variable names @@ -56,6 +73,16 @@ def test_root_mean_squared_error(random_estimates, random_targets): assert out["metric_name"] == "NRMSE" assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.metrics.root_mean_squared_error( + random_estimates, random_targets, test_quantities=test_quantities + ) + assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates) + def test_classifier_two_sample_test(random_samples_a, random_samples_b): metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, 
targets=random_samples_a) @@ -95,6 +122,16 @@ def test_calibration_log_gamma(random_estimates, random_targets): assert out["metric_name"] == "Log Gamma" assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.metrics.calibration_log_gamma( + random_estimates, random_targets, test_quantities=test_quantities + ) + assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates) + def test_calibration_log_gamma_end_to_end(): # This is a function test for simulation-based calibration. diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index 5d4758558..f825c36cc 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -85,6 +85,16 @@ def test_calibration_histogram(random_estimates, random_targets): assert len(out.axes) == num_variables(random_estimates) assert out.axes[0].title._text == "beta_0" + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets, test_quantities=test_quantities) + assert len(out.axes) == len(test_quantities) + num_variables(random_estimates) + assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$" + assert out.axes[-1].title._text == r"sigma" + def test_loss(history): out = bf.diagnostics.loss(history) @@ -102,6 +112,16 @@ def test_recovery_bounds(random_estimates, random_targets): assert len(out.axes) == num_variables(random_estimates) assert out.axes[2].title._text == "sigma" + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: 
np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets, test_quantities=test_quantities) + assert len(out.axes) == len(test_quantities) + num_variables(random_estimates) + assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$" + assert out.axes[-1].title._text == r"sigma" + def test_recovery_symmetric(random_estimates, random_targets): # basic functionality: automatic variable names @@ -127,6 +147,16 @@ def test_z_score_contraction(random_estimates, random_targets): assert len(out.axes) == num_variables(random_estimates) assert out.axes[1].title._text == "beta_1" + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets, test_quantities=test_quantities) + assert len(out.axes) == len(test_quantities) + num_variables(random_estimates) + assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$" + assert out.axes[-1].title._text == r"sigma" + def test_pairs_samples(random_priors): out = bf.diagnostics.plots.pairs_samples( @@ -291,6 +321,16 @@ def test_coverage(random_estimates, random_targets): assert out.axes[0].get_xlabel() == "Central interval width" assert out.axes[0].get_ylabel() == "Empirical coverage" + # test quantities + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.plots.coverage(random_estimates, random_targets, test_quantities=test_quantities) + assert len(out.axes) == len(test_quantities) + num_variables(random_estimates) + assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$" + assert out.axes[-1].title._text == r"sigma" + def 
test_coverage_diff(random_estimates, random_targets): # basic functionality: automatic variable names diff --git a/tests/test_links/test_links.py b/tests/test_links/test_links.py index 9285f689b..396bc0bfd 100644 --- a/tests/test_links/test_links.py +++ b/tests/test_links/test_links.py @@ -23,9 +23,9 @@ def check_ordering(output, axis): assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." for i in range(output.ndim): if i != axis % output.ndim: - assert not np.all(np.diff(output, axis=i) > 0), ( - f"is ordered along axis which is not meant to be ordered: {i}." - ) + assert not np.all( + np.diff(output, axis=i) > 0 + ), f"is ordered along axis which is not meant to be ordered: {i}." @pytest.mark.parametrize("axis", [0, 1, 2]) diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 38ba8ea4e..8992e923a 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random for key_outer in out1.keys(): for key_inner in out1[key_outer].keys(): - assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), ( - "Output of original and loaded model differs significantly." - ) + assert keras.ops.all( + keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner]) + ), "Output of original and loaded model differs significantly." 
def test_copy_unequal(point_inference_network, random_samples, random_conditions): From a731cb525b216995145d1116d49370c53fdb017e Mon Sep 17 00:00:00 2001 From: Svenja Jedhoff Date: Wed, 15 Oct 2025 16:35:21 +0200 Subject: [PATCH 5/6] adding one stupid blank line --- tests/test_diagnostics/test_diagnostics_metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index c2b5dba5c..daad874d0 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -43,6 +43,7 @@ def test_metric_calibration_error(random_estimates, random_targets, var_names): out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets, test_quantities=test_quantities) assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates) + def test_posterior_contraction(random_estimates, random_targets): # basic functionality: automatic variable names out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets) From 80481bed014781211968116a7d5566b966158de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Wed, 15 Oct 2025 16:37:38 +0200 Subject: [PATCH 6/6] again run ruff --- tests/test_links/test_links.py | 6 +++--- .../test_point_inference_network.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_links/test_links.py b/tests/test_links/test_links.py index 396bc0bfd..9285f689b 100644 --- a/tests/test_links/test_links.py +++ b/tests/test_links/test_links.py @@ -23,9 +23,9 @@ def check_ordering(output, axis): assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." for i in range(output.ndim): if i != axis % output.ndim: - assert not np.all( - np.diff(output, axis=i) > 0 - ), f"is ordered along axis which is not meant to be ordered: {i}." 
+ assert not np.all(np.diff(output, axis=i) > 0), ( + f"is ordered along axis which is not meant to be ordered: {i}." + ) @pytest.mark.parametrize("axis", [0, 1, 2]) diff --git a/tests/test_networks/test_point_inference_network/test_point_inference_network.py b/tests/test_networks/test_point_inference_network/test_point_inference_network.py index 8992e923a..38ba8ea4e 100644 --- a/tests/test_networks/test_point_inference_network/test_point_inference_network.py +++ b/tests/test_networks/test_point_inference_network/test_point_inference_network.py @@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random for key_outer in out1.keys(): for key_inner in out1[key_outer].keys(): - assert keras.ops.all( - keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner]) - ), "Output of original and loaded model differs significantly." + assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), ( + "Output of original and loaded model differs significantly." + ) def test_copy_unequal(point_inference_network, random_samples, random_conditions):