diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 420b9a9a2..5a4f9702f 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -3,6 +3,7 @@ import numpy as np import keras +import warnings from bayesflow.adapters import Adapter from bayesflow.networks import InferenceNetwork, SummaryNetwork @@ -539,7 +540,7 @@ def _sample( batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample) ) - def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: + def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: """ Computes the learned summary statistics of given summary variables. @@ -570,6 +571,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: return summaries + def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: + """ + .. deprecated:: 2.0.4 + `summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead. + """ + warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning) + return self.summarize(data=data, **kwargs) + def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: """ Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index 4f23bcfcb..608158d2b 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -2,6 +2,7 @@ import keras import numpy as np +import warnings from bayesflow.adapters import Adapter from bayesflow.datasets import OnlineDataset @@ -404,7 +405,7 @@ def predict( return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output) - def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: + def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: """ Computes the learned summary statistics of given summary variables. @@ -435,6 +436,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: return summaries + def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray: + """ + .. deprecated:: 2.0.4 + `summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead. + """ + warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning) + return self.summarize(data=data, **kwargs) + def _compute_logits(self, classifier_conditions: Tensor) -> Tensor: """Helper to compute projected logits from the classifier network.""" logits = self.classifier_network(classifier_conditions) diff --git a/bayesflow/diagnostics/metrics/model_misspecification.py b/bayesflow/diagnostics/metrics/model_misspecification.py index c698d4eb2..802b2d336 100644 --- a/bayesflow/diagnostics/metrics/model_misspecification.py +++ b/bayesflow/diagnostics/metrics/model_misspecification.py @@ -142,8 +142,8 @@ def summary_space_comparison( "statistics, or want to compare raw data and not summary statistics, please use the " f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays." ) - observed_summaries = convert_to_numpy(approximator.summaries(observed_data)) - reference_summaries = convert_to_numpy(approximator.summaries(reference_data)) + observed_summaries = convert_to_numpy(approximator.summarize(observed_data)) + reference_summaries = convert_to_numpy(approximator.summarize(reference_data)) distance_observed, distance_null = bootstrap_comparison( observed_samples=observed_summaries, diff --git a/tests/test_approximators/test_summaries.py b/tests/test_approximators/test_summarize.py similarity index 82% rename from tests/test_approximators/test_summaries.py rename to tests/test_approximators/test_summarize.py index 7962ddaab..04509cc3a 100644 --- a/tests/test_approximators/test_summaries.py +++ b/tests/test_approximators/test_summarize.py @@ -5,7 +5,7 @@ def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch): monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network) - summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))}) + summaries = approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))}) assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1)) @@ -13,11 +13,11 @@ def test_no_summary_network(approximator_with_summaries, monkeypatch): monkeypatch.setattr(approximator_with_summaries, "summary_network", None) with pytest.raises(ValueError): - approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))}) + approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))}) def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch): monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network) with pytest.raises(ValueError): - approximator_with_summaries.summaries({}) + approximator_with_summaries.summarize({})