2525from keras .ops import convert_to_numpy , convert_to_tensor
2626
2727from bayesflow .approximators import ContinuousApproximator
28+ from bayesflow .networks import SummaryNetwork
2829from bayesflow .metrics .functional import maximum_mean_discrepancy
2930from bayesflow .types import Tensor
3031
@@ -124,7 +125,7 @@ def compute_mmd_hypothesis_test_from_summaries(
124125def compute_mmd_hypothesis_test (
125126 observed_data : np .ndarray ,
126127 reference_data : np .ndarray ,
127- approximator : ContinuousApproximator ,
128+ approximator : ContinuousApproximator | SummaryNetwork ,
128129 num_null_samples : int = 100 ,
129130) -> tuple [float , np .ndarray ]:
130131 """Computes the Maximum Mean Discrepancy (MMD) between observed and reference data and generates a distribution of
@@ -155,8 +156,8 @@ def compute_mmd_hypothesis_test(
155156 Observed data, shape (num_observed, ...).
156157 reference_data : np.ndarray
157158 Reference data, shape (num_reference, ...).
158- approximator : ContinuousApproximator
159- An instance of the ContinuousApproximator class used to obtain summary statistics from data.
159+ approximator : ContinuousApproximator or SummaryNetwork
160+ An instance of the ContinuousApproximator or SummaryNetwork class used to extract summary statistics from data.
160161 num_null_samples : int
161162 Number of null samples to generate for hypothesis testing. Default is 100.
162163
@@ -178,14 +179,22 @@ def compute_mmd_hypothesis_test(
178179 f"but got { observed_data .shape [1 :]} != { reference_data .shape [1 :]} ."
179180 )
180181
181- if approximator .summary_network is not None :
182+ if isinstance (approximator , ContinuousApproximator ):
183+ if approximator .summary_network is not None :
184+ observed_data_tensor : Tensor = convert_to_tensor (observed_data )
185+ reference_data_tensor : Tensor = convert_to_tensor (reference_data )
186+ observed_summaries : np .ndarray = convert_to_numpy (approximator .summary_network (observed_data_tensor ))
187+ reference_summaries : np .ndarray = convert_to_numpy (approximator .summary_network (reference_data_tensor ))
188+ else :
189+ observed_summaries : np .ndarray = observed_data
190+ reference_summaries : np .ndarray = reference_data
191+ elif isinstance (approximator , SummaryNetwork ):
182192 observed_data_tensor : Tensor = convert_to_tensor (observed_data )
183193 reference_data_tensor : Tensor = convert_to_tensor (reference_data )
184- observed_summaries : np .ndarray = convert_to_numpy (approximator . summary_network (observed_data_tensor ))
185- reference_summaries : np .ndarray = convert_to_numpy (approximator . summary_network (reference_data_tensor ))
194+ observed_summaries : np .ndarray = convert_to_numpy (approximator (observed_data_tensor ))
195+ reference_summaries : np .ndarray = convert_to_numpy (approximator (reference_data_tensor ))
186196 else :
187- observed_summaries : np .ndarray = observed_data
188- reference_summaries : np .ndarray = reference_data
197+ raise ValueError ("The approximator must be an instance of ContinuousApproximator or SummaryNetwork." )
189198
190199 mmd_observed , mmd_null = compute_mmd_hypothesis_test_from_summaries (
191200 observed_summaries , reference_summaries , num_null_samples = num_null_samples
0 commit comments