Skip to content

Commit ef9dd00

Browse files
committed
handle case for when ContinuousApproximator.summary_network is None
1 parent c8f54c0 commit ef9dd00

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,15 @@ def compute_mmd_hypothesis_test(
147147
mmd_null : np.ndarray
148148
A distribution of MMD values under the null hypothesis.
149149
"""
150-
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
151-
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
152150

153-
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
154-
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
151+
if approximator.summary_network is not None:
152+
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
153+
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
154+
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
155+
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
156+
else:
157+
observed_summaries: np.ndarray = convert_to_numpy(observed_data_tensor)
158+
reference_summaries: np.ndarray = convert_to_numpy(reference_data_tensor)
155159

156160
mmd_observed, mmd_null = compute_mmd_hypothesis_test_from_summaries(
157161
observed_summaries, reference_summaries, num_null_samples=num_null_samples

0 commit comments

Comments
 (0)