Skip to content

Commit dbdbe2e

Browse files
committed
allow approximator argument to also be of bayesflow.network.SummaryNetwork in addition to bayesflow.approximators.ContinuousApproximator
1 parent f67da8a commit dbdbe2e

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from keras.ops import convert_to_numpy, convert_to_tensor
2626

2727
from bayesflow.approximators import ContinuousApproximator
28+
from bayesflow.networks import SummaryNetwork
2829
from bayesflow.metrics.functional import maximum_mean_discrepancy
2930
from bayesflow.types import Tensor
3031

@@ -124,7 +125,7 @@ def compute_mmd_hypothesis_test_from_summaries(
124125
def compute_mmd_hypothesis_test(
125126
observed_data: np.ndarray,
126127
reference_data: np.ndarray,
127-
approximator: ContinuousApproximator,
128+
approximator: ContinuousApproximator | SummaryNetwork,
128129
num_null_samples: int = 100,
129130
) -> tuple[float, np.ndarray]:
130131
"""Computes the Maximum Mean Discrepancy (MMD) between observed and reference data and generates a distribution of
@@ -155,8 +156,8 @@ def compute_mmd_hypothesis_test(
155156
Observed data, shape (num_observed, ...).
156157
reference_data : np.ndarray
157158
Reference data, shape (num_reference, ...).
158-
approximator : ContinuousApproximator
159-
An instance of the ContinuousApproximator class used to obtain summary statistics from data.
159+
approximator : ContinuousApproximator or SummaryNetwork
160+
An instance of the ContinuousApproximator or SummaryNetwork class used to extract summary statistics from data.
160161
num_null_samples : int
161162
Number of null samples to generate for hypothesis testing. Default is 100.
162163
@@ -178,14 +179,22 @@ def compute_mmd_hypothesis_test(
178179
f"but got {observed_data.shape[1:]} != {reference_data.shape[1:]}."
179180
)
180181

181-
if approximator.summary_network is not None:
182+
if isinstance(approximator, ContinuousApproximator):
183+
if approximator.summary_network is not None:
184+
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
185+
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
186+
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
187+
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
188+
else:
189+
observed_summaries: np.ndarray = observed_data
190+
reference_summaries: np.ndarray = reference_data
191+
elif isinstance(approximator, SummaryNetwork):
182192
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
183193
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
184-
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
185-
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
194+
observed_summaries: np.ndarray = convert_to_numpy(approximator(observed_data_tensor))
195+
reference_summaries: np.ndarray = convert_to_numpy(approximator(reference_data_tensor))
186196
else:
187-
observed_summaries: np.ndarray = observed_data
188-
reference_summaries: np.ndarray = reference_data
197+
raise ValueError("The approximator must be an instance of ContinuousApproximator or SummaryNetwork.")
189198

190199
mmd_observed, mmd_null = compute_mmd_hypothesis_test_from_summaries(
191200
observed_summaries, reference_summaries, num_null_samples=num_null_samples

0 commit comments

Comments
 (0)