diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index bf4e263a0..834521d4b 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -400,6 +400,39 @@ def _sample(
             **filter_kwargs(kwargs, self.inference_network.sample),
         )
 
+    def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
+        """
+        Computes the summaries of given data.
+
+        The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
+
+        Parameters
+        ----------
+        data : Mapping[str, np.ndarray]
+            Dictionary of data as NumPy arrays.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and the summary network.
+
+        Returns
+        -------
+        summaries : Tensor
+            The summary statistics of the data, as computed by the summary network.
+
+        Raises
+        ------
+        ValueError
+            If the approximator does not have a summary network, or the adapter does not produce the output required
+            by the summary network.
+        """
+        if self.summary_network is None:
+            raise ValueError("A summary network is required to compute summaries.")
+        data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+        if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
+            raise ValueError("Summary variables are required to compute summaries.")
+        summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
+        summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
+        return summaries
+
     def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
         """
         Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py
index 1b9d198ff..028e8837a 100644
--- a/bayesflow/approximators/model_comparison_approximator.py
+++ b/bayesflow/approximators/model_comparison_approximator.py
@@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
         output = self.logits_projector(output)
 
         return output
+
+    def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
+        """
+        Computes the summaries of given data.
+
+        The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
+
+        Parameters
+        ----------
+        data : Mapping[str, np.ndarray]
+            Dictionary of data as NumPy arrays.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and the summary network.
+
+        Returns
+        -------
+        summaries : Tensor
+            The summary statistics of the data, as computed by the summary network.
+
+        Raises
+        ------
+        ValueError
+            If the approximator does not have a summary network, or the adapter does not produce the output required
+            by the summary network.
+        """
+        if self.summary_network is None:
+            raise ValueError("A summary network is required to compute summaries.")
+        data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+        if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
+            raise ValueError("Summary variables are required to compute summaries.")
+        summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
+        summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
+        return summaries
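Not part of the diff, but for orientation: a minimal sketch of how the new `summaries()` method could be called once an approximator with a summary network has been trained. The variable and key names (`approximator`, `observables`) are illustrative placeholders, not identifiers introduced by this PR.

```python
import numpy as np

# Hypothetical trained ContinuousApproximator whose adapter maps "observables"
# to "summary_variables"; both names are placeholders for illustration only.
observed = {"observables": np.random.normal(size=(16, 100, 2))}

# Adapter preprocessing + summary network forward pass in one call.
summaries = approximator.summaries(observed)
print(summaries.shape)  # (16, summary_dim): a backend tensor returned by the summary network
```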
diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py
index 1e13e11f2..87823c754 100644
--- a/bayesflow/diagnostics/__init__.py
+++ b/bayesflow/diagnostics/__init__.py
@@ -1,8 +1,13 @@
-"""
+r"""
 A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s.
 """
 
-from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
+from .metrics import (
+    bootstrap_comparison,
+    calibration_error,
+    posterior_contraction,
+    summary_space_comparison,
+)
 
 from .plots import (
     calibration_ecdf,
diff --git a/bayesflow/diagnostics/metrics/__init__.py b/bayesflow/diagnostics/metrics/__init__.py
index ceeca4cc4..3e3496cda 100644
--- a/bayesflow/diagnostics/metrics/__init__.py
+++ b/bayesflow/diagnostics/metrics/__init__.py
@@ -3,3 +3,4 @@
 from .root_mean_squared_error import root_mean_squared_error
 from .expected_calibration_error import expected_calibration_error
 from .classifier_two_sample_test import classifier_two_sample_test
+from .model_misspecification import bootstrap_comparison, summary_space_comparison
diff --git a/bayesflow/diagnostics/metrics/model_misspecification.py b/bayesflow/diagnostics/metrics/model_misspecification.py
new file mode 100644
index 000000000..c698d4eb2
--- /dev/null
+++ b/bayesflow/diagnostics/metrics/model_misspecification.py
@@ -0,0 +1,155 @@
+"""
+This module provides functions for computing distances between observed samples and reference samples, along with
+null distributions of distances within the reference samples, for hypothesis testing.
+"""
+
+from collections.abc import Mapping, Callable
+
+import numpy as np
+from keras.ops import convert_to_numpy, convert_to_tensor
+
+from bayesflow.approximators import ContinuousApproximator
+from bayesflow.metrics.functional import maximum_mean_discrepancy
+from bayesflow.types import Tensor
+
+
+def bootstrap_comparison(
+    observed_samples: np.ndarray,
+    reference_samples: np.ndarray,
+    comparison_fn: Callable[[Tensor, Tensor], Tensor],
+    num_null_samples: int = 100,
+) -> tuple[float, np.ndarray]:
+    """Computes the distance between observed and reference samples and generates a distribution of null sample
+    distances by bootstrapping for hypothesis testing.
+
+    Parameters
+    ----------
+    observed_samples : np.ndarray
+        Observed samples, shape (num_observed, ...).
+    reference_samples : np.ndarray
+        Reference samples, shape (num_reference, ...).
+    comparison_fn : Callable[[Tensor, Tensor], Tensor]
+        Function to compute the distance metric.
+    num_null_samples : int
+        Number of null samples to generate for hypothesis testing. Default is 100.
+
+    Returns
+    -------
+    distance_observed : float
+        The distance value between observed and reference samples.
+    distance_null : np.ndarray
+        A distribution of distance values under the null hypothesis.
+
+    Raises
+    ------
+    ValueError
+        - If the number of observed samples exceeds the number of reference samples.
+        - If the shapes of observed and reference samples do not match on dimensions besides the first one.
+    """
+    num_observed: int = observed_samples.shape[0]
+    num_reference: int = reference_samples.shape[0]
+
+    if num_observed > num_reference:
+        raise ValueError(
+            f"Number of observed samples ({num_observed}) cannot exceed "
+            f"the number of reference samples ({num_reference}) for bootstrapping."
+        )
+    if observed_samples.shape[1:] != reference_samples.shape[1:]:
+        raise ValueError(
+            f"Expected observed and reference samples to have the same shape, "
+            f"but got {observed_samples.shape[1:]} != {reference_samples.shape[1:]}."
+        )
+
+    observed_samples_tensor: Tensor = convert_to_tensor(observed_samples, dtype="float32")
+    reference_samples_tensor: Tensor = convert_to_tensor(reference_samples, dtype="float32")
+
+    distance_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64)
+    for i in range(num_null_samples):
+        bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed)
+        bootstrap_samples: np.ndarray = reference_samples[bootstrap_idx]
+        bootstrap_samples_tensor: Tensor = convert_to_tensor(bootstrap_samples, dtype="float32")
+        distance_null_samples[i] = convert_to_numpy(comparison_fn(bootstrap_samples_tensor, reference_samples_tensor))
+
+    distance_observed_tensor: Tensor = comparison_fn(
+        observed_samples_tensor,
+        reference_samples_tensor,
+    )
+
+    distance_observed: float = float(convert_to_numpy(distance_observed_tensor))
+
+    return distance_observed, distance_null_samples
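A short usage sketch for `bootstrap_comparison` on raw arrays (e.g., hand-crafted summary statistics). The data and the mean-difference comparison function below are illustrative only and mirror the tests added later in this PR.

```python
import numpy as np
import keras
import bayesflow as bf

# Illustrative data; any arrays with matching trailing dimensions work.
observed = np.random.normal(loc=0.3, size=(10, 5))
reference = np.random.normal(loc=0.0, size=(100, 5))

# Any callable mapping two tensors to a scalar distance can be used;
# here, a simple absolute difference of means, as in the test suite.
distance, null_distances = bf.diagnostics.metrics.bootstrap_comparison(
    observed_samples=observed,
    reference_samples=reference,
    comparison_fn=lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
    num_null_samples=200,
)

# One-sided bootstrap p-value: how extreme is the observed distance under the null?
p_value = np.mean(null_distances >= distance)
```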
+
+
+def summary_space_comparison(
+    observed_data: Mapping[str, np.ndarray],
+    reference_data: Mapping[str, np.ndarray],
+    approximator: ContinuousApproximator,
+    num_null_samples: int = 100,
+    comparison_fn: Callable = maximum_mean_discrepancy,
+    **kwargs,
+) -> tuple[float, np.ndarray]:
+    """Computes the distance between observed and reference data in the summary space and
+    generates a distribution of distance values under the null hypothesis to assess model misspecification.
+
+    By default, the Maximum Mean Discrepancy (MMD) is used as a distance function.
+
+    [1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
+    inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
+    URL: https://arxiv.org/abs/2112.08866
+
+    Parameters
+    ----------
+    observed_data : dict[str, np.ndarray]
+        Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximator's adapter and
+        passed through its summary network.
+    reference_data : dict[str, np.ndarray]
+        Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximator's adapter and
+        passed through its summary network.
+    approximator : ContinuousApproximator
+        An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics
+        from the data.
+    num_null_samples : int, optional
+        Number of null samples to generate for hypothesis testing. Default is 100.
+    comparison_fn : Callable, optional
+        Distance function to compare the data in the summary space. Default is `maximum_mean_discrepancy`.
+    **kwargs : dict
+        Additional keyword arguments passed to the approximator's summary computation (adapter and summary network).
+
+    Returns
+    -------
+    distance_observed : float
+        The MMD value between observed and reference summaries.
+    distance_null : np.ndarray
+        A distribution of MMD values under the null hypothesis.
+
+    Raises
+    ------
+    ValueError
+        If approximator is not an instance of ContinuousApproximator or does not have a summary network.
+    """
+
+    if not isinstance(approximator, ContinuousApproximator):
+        raise ValueError("The approximator must be an instance of ContinuousApproximator.")
+
+    if not hasattr(approximator, "summary_network") or approximator.summary_network is None:
+        comparison_fn_name = (
+            "bayesflow.metrics.functional.maximum_mean_discrepancy"
+            if comparison_fn is maximum_mean_discrepancy
+            else comparison_fn.__name__
+        )
+        raise ValueError(
+            "The approximator must have a summary network. If you have manually crafted summary "
+            "statistics, or want to compare raw data and not summary statistics, please use the "
+            f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
+        )
+    observed_summaries = convert_to_numpy(approximator.summaries(observed_data, **kwargs))
+    reference_summaries = convert_to_numpy(approximator.summaries(reference_data, **kwargs))
+
+    distance_observed, distance_null = bootstrap_comparison(
+        observed_samples=observed_summaries,
+        reference_samples=reference_summaries,
+        comparison_fn=comparison_fn,
+        num_null_samples=num_null_samples,
+    )
+
+    return distance_observed, distance_null
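A sketch of the intended end-to-end use of `summary_space_comparison`, assuming `approximator` is an already trained `ContinuousApproximator` with a summary network and `real_observations` holds the observed data. All names here are placeholders, not part of the diff.

```python
import numpy as np
import bayesflow as bf

# Hypothetical inputs: data simulated from the training model vs. actually observed data.
reference_data = {"observables": np.random.normal(size=(500, 100, 2))}  # from the simulator
observed_data = {"observables": real_observations}                      # e.g., real measurements

mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison(
    observed_data=observed_data,
    reference_data=reference_data,
    approximator=approximator,  # trained ContinuousApproximator with a summary network
    num_null_samples=100,
)

# Flag misspecification if the observed MMD is extreme under the null distribution.
p_value = np.mean(mmd_null >= mmd_observed)
print(f"MMD = {mmd_observed:.4f}, bootstrap p-value = {p_value:.3f}")
```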
diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py
index 125371a52..227e70ff1 100644
--- a/tests/test_approximators/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator):
     num_batches = 2
     data = simulator.sample((num_batches * batch_size,))
     return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches)
+
+
+@pytest.fixture()
+def mean_std_summary_network():
+    from tests.utils import MeanStdSummaryNetwork
+
+    return MeanStdSummaryNetwork()
+
+
+@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
+def approximator_with_summaries(request):
+    from bayesflow.adapters import Adapter
+
+    adapter = Adapter()
+    match request.param:
+        case "continuous_approximator":
+            from bayesflow.approximators import ContinuousApproximator
+
+            return ContinuousApproximator(adapter=adapter, inference_network=None, summary_network=None)
+        case "point_approximator":
+            from bayesflow.approximators import PointApproximator
+
+            return PointApproximator(adapter=adapter, inference_network=None, summary_network=None)
+        case "model_comparison_approximator":
+            from bayesflow.approximators import ModelComparisonApproximator
+
+            return ModelComparisonApproximator(
+                num_models=2, classifier_network=None, adapter=adapter, summary_network=None
+            )
+        case _:
+            raise ValueError("Invalid param for approximator class.")
diff --git a/tests/test_approximators/test_summaries.py b/tests/test_approximators/test_summaries.py
new file mode 100644
index 000000000..7962ddaab
--- /dev/null
+++ b/tests/test_approximators/test_summaries.py
@@ -0,0 +1,23 @@
+import pytest
+from tests.utils import assert_allclose
+import keras
+
+
+def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
+    monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
+    summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
+    assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
+
+
+def test_no_summary_network(approximator_with_summaries, monkeypatch):
+    
monkeypatch.setattr(approximator_with_summaries, "summary_network", None) + + with pytest.raises(ValueError): + approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))}) + + +def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch): + monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network) + + with pytest.raises(ValueError): + approximator_with_summaries.summaries({}) diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 8e77d6729..dc859d2d4 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -78,3 +78,17 @@ def history(): } return h + + +@pytest.fixture() +def adapter(): + from bayesflow.adapters import Adapter + + return Adapter.create_default("parameters").rename("observables", "summary_variables") + + +@pytest.fixture() +def summary_network(): + from tests.utils import MeanStdSummaryNetwork + + return MeanStdSummaryNetwork() diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index 4fb0945b3..3a2c711bc 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -1,6 +1,9 @@ -import bayesflow as bf +import numpy as np +import keras import pytest +import bayesflow as bf + def num_variables(x: dict): return sum(arr.shape[-1] for arr in x.values()) @@ -79,3 +82,288 @@ def test_expected_calibration_error(pred_models, true_models, model_names): with pytest.raises(Exception): out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose) + + +def test_bootstrap_comparison_shapes(): + """Test the bootstrap_comparison output shapes.""" + observed_samples = np.random.rand(10, 5) + reference_samples = np.random.rand(100, 5) + num_null_samples = 50 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert isinstance(distance_observed, float) + assert isinstance(distance_null, np.ndarray) + assert distance_null.shape == (num_null_samples,) + + +def test_bootstrap_comparison_same_distribution(): + """Test bootstrap_comparison on same distributions.""" + observed_samples = np.random.normal(loc=0.5, scale=0.1, size=(10, 5)) + reference_samples = observed_samples.copy() + num_null_samples = 5 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert distance_observed <= np.quantile(distance_null, 0.99) + + +def test_bootstrap_comparison_different_distributions(): + """Test bootstrap_comparison on different distributions.""" + observed_samples = np.random.normal(loc=-5, scale=0.1, size=(10, 5)) + reference_samples = np.random.normal(loc=5, scale=0.1, size=(100, 5)) + num_null_samples = 50 + + distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + assert distance_observed >= np.quantile(distance_null, 0.68) + + +def test_bootstrap_comparison_mismatched_shapes(): + """Test bootstrap_comparison raises ValueError for mismatched shapes.""" + observed_samples = np.random.rand(10, 
5) + reference_samples = np.random.rand(20, 4) + num_null_samples = 10 + + with pytest.raises(ValueError): + bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + +def test_bootstrap_comparison_num_observed_exceeds_num_reference(): + """Test bootstrap_comparison raises ValueError when number of observed samples exceeds the number of reference + samples.""" + observed_samples = np.random.rand(100, 5) + reference_samples = np.random.rand(20, 5) + num_null_samples = 50 + + with pytest.raises(ValueError): + bf.diagnostics.metrics.bootstrap_comparison( + observed_samples, + reference_samples, + lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)), + num_null_samples, + ) + + +def test_mmd_comparison_from_summaries_shapes(): + """Test the mmd_comparison_from_summaries output shapes.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.rand(100, 5) + num_null_samples = 50 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert isinstance(mmd_observed, float) + assert isinstance(mmd_null, np.ndarray) + assert mmd_null.shape == (num_null_samples,) + + +def test_mmd_comparison_from_summaries_positive(): + """Test MMD output values of mmd_comparison_from_summaries are positive.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.rand(100, 5) + num_null_samples = 50 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed >= 0 + assert np.all(mmd_null >= 0) + + +def test_mmd_comparison_from_summaries_same_distribution(): + """Test mmd_comparison_from_summaries on same distributions.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = observed_summaries.copy() + num_null_samples = 5 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed <= np.quantile(mmd_null, 0.99) + + +def test_mmd_comparison_from_summaries_different_distributions(): + """Test mmd_comparison_from_summaries on different distributions.""" + observed_summaries = np.random.rand(10, 5) + reference_summaries = np.random.normal(loc=0.5, scale=0.1, size=(100, 5)) + num_null_samples = 50 + + mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison( + observed_summaries, + reference_summaries, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + num_null_samples=num_null_samples, + ) + + assert mmd_observed >= np.quantile(mmd_null, 0.68) + + +def test_mmd_comparison_shapes(summary_network, adapter): + """Test the mmd_comparison output shapes.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + 
reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert isinstance(mmd_observed, float) + assert isinstance(mmd_null, np.ndarray) + assert mmd_null.shape == (num_null_samples,) + + +def test_mmd_comparison_positive(summary_network, adapter): + """Test MMD output values of mmd_comparison are positive.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed >= 0 + assert np.all(mmd_null >= 0) + + +def test_mmd_comparison_same_distribution(summary_network, adapter): + """Test mmd_comparison on same distributions.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = observed_data + num_null_samples = 5 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed <= np.quantile(mmd_null, 0.99) + + +def test_mmd_comparison_different_distributions(summary_network, adapter): + """Test mmd_comparison on different distributions.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.normal(loc=0.5, scale=0.1, size=(100, 5))) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=summary_network, + ) + + mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + assert mmd_observed >= np.quantile(mmd_null, 0.68) + + +def test_mmd_comparison_no_summary_network(adapter): + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + mock_approximator = bf.approximators.ContinuousApproximator( + adapter=adapter, + inference_network=None, + summary_network=None, + ) + + with pytest.raises(ValueError): + bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) + + +def test_mmd_comparison_approximator_incorrect_instance(): + """Test mmd_comparison raises ValueError for incorrect approximator instance.""" + observed_data = dict(observables=np.random.rand(10, 5)) + reference_data = dict(observables=np.random.rand(100, 5)) + num_null_samples = 50 + + class IncorrectApproximator: + pass + + 
mock_approximator = IncorrectApproximator() + + with pytest.raises(ValueError): + bf.diagnostics.metrics.summary_space_comparison( + observed_data=observed_data, + reference_data=reference_data, + approximator=mock_approximator, + num_null_samples=num_null_samples, + comparison_fn=bf.metrics.functional.maximum_mean_discrepancy, + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index f36b02bbd..9c2affc22 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -2,4 +2,5 @@ from .callbacks import * from .check_combinations import * from .jupyter import * +from .networks import * from .ops import * diff --git a/tests/utils/networks.py b/tests/utils/networks.py new file mode 100644 index 000000000..cf35e1463 --- /dev/null +++ b/tests/utils/networks.py @@ -0,0 +1,8 @@ +from bayesflow.networks import SummaryNetwork +import keras + + +class MeanStdSummaryNetwork(SummaryNetwork): + def call(self, x): + summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1) + return summary_outputs
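For reference, the `MeanStdSummaryNetwork` helper above reduces over the last axis, mapping each input row to its (mean, std) pair; that is why `test_valid_summaries` expects `stack([ones((2,)), zeros((2,))], axis=-1)` for an all-ones input. A quick, illustrative sanity check:

```python
import keras
from tests.utils import MeanStdSummaryNetwork

network = MeanStdSummaryNetwork()
out = network(keras.ops.ones((2, 3)))
# Each row of ones has mean 1 and standard deviation 0, giving shape (2, 2).
print(keras.ops.convert_to_numpy(out))  # [[1. 0.], [1. 0.]]
```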