|
43 | 43 | """ |
44 | 44 |
|
45 | 45 | import numpy as np |
| 46 | +from keras.ops import covert_to_numpy, covert_to_tensor |
46 | 47 |
|
47 | 48 | from bayesflow.approximators import Approximator |
48 | 49 | from bayesflow.metrics import maximum_mean_discrepancy |
49 | 50 |
|
50 | 51 |
|
51 | | -# TODO: maximum_mean_discrepancy expects bayesflow.types.Tensor instead of np.ndarray as input and returns |
52 | | -# bayesflow.types.Tensor instead of float |
53 | 52 | def mmd_hypothesis_test_from_summaries( |
54 | 53 | observed_summaries: np.ndarray, |
55 | 54 | reference_summaries: np.ndarray, |
@@ -82,14 +81,15 @@ def mmd_hypothesis_test_from_summaries( |
82 | 81 | for i in range(num_null_samples): |
83 | 82 | bootstrap_idx: int = np.random.randint(0, num_reference, size=num_observed) |
84 | 83 | sampled_summaries: np.ndarray = reference_summaries[bootstrap_idx] |
85 | | - mmd_null_samples[i] = maximum_mean_discrepancy(observed_summaries, sampled_summaries) |
| 84 | + mmd_null_samples[i] = maximum_mean_discrepancy( |
| 85 | + covert_to_tensor(observed_summaries), covert_to_tensor(sampled_summaries) |
| 86 | + ) |
86 | 87 |
|
87 | | - mmd_observed: float = maximum_mean_discrepancy(observed_summaries, reference_summaries) |
| 88 | + mmd_observed: float = float(covert_to_numpy(maximum_mean_discrepancy(observed_summaries, reference_summaries))) |
88 | 89 |
|
89 | 90 | return mmd_observed, mmd_null_samples |
90 | 91 |
|
91 | 92 |
|
92 | | -# TODO: approximator.summary_network takes and returns bayesflow.types.Tensor |
93 | 93 | def mmd_hypothesis_test( |
94 | 94 | observed_data: np.ndarray, |
95 | 95 | reference_data: np.ndarray, |
@@ -117,8 +117,8 @@ def mmd_hypothesis_test( |
117 | 117 | mmd_null : np.ndarray |
118 | 118 | A distribution of MMD values under the null hypothesis. |
119 | 119 | """ |
120 | | - observed_summaries: np.ndarray = approximator.summary_network(observed_data) |
121 | | - reference_summaries: np.ndarray = approximator.summary_network(reference_data) |
| 120 | + observed_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(observed_data))) |
| 121 | + reference_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(reference_data))) |
122 | 122 |
|
123 | 123 | mmd_observed, mmd_null = mmd_hypothesis_test_from_summaries( |
124 | 124 | observed_summaries, reference_summaries, num_null_samples=num_null_samples |
|
0 commit comments