|
22 | 22 | """ |
23 | 23 |
|
24 | 24 | import numpy as np |
25 | | -from keras.ops import covert_to_numpy, covert_to_tensor |
| 25 | +from keras.ops import convert_to_numpy, convert_to_tensor |
26 | 26 |
|
27 | 27 | from bayesflow.approximators import Approximator |
28 | | -from bayesflow.metrics import maximum_mean_discrepancy |
| 28 | +from bayesflow.metrics.functional import maximum_mean_discrepancy |
| 29 | +from bayesflow.types import Tensor |
29 | 30 |
|
30 | 31 |
|
31 | 32 | def compute_mmd_hypothesis_test_from_summaries( |
@@ -71,19 +72,31 @@ def compute_mmd_hypothesis_test_from_summaries( |
71 | 72 | mmd_null : np.ndarray |
72 | 73 | A distribution of MMD values under the null hypothesis. |
73 | 74 | """ |
| 75 | + observed_summaries_tensor: Tensor = convert_to_tensor(observed_summaries, dtype="float32") |
| 76 | + reference_summaries_tensor: Tensor = convert_to_tensor(reference_summaries, dtype="float32") |
| 77 | + |
74 | 78 | num_observed: int = observed_summaries.shape[0] |
75 | 79 | num_reference: int = reference_summaries.shape[0] |
76 | 80 |
|
77 | 81 | mmd_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64) |
78 | 82 |
|
79 | 83 | for i in range(num_null_samples): |
80 | | - bootstrap_idx: int = np.random.randint(0, num_reference, size=num_observed) |
| 84 | + bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed) |
81 | 85 | sampled_summaries: np.ndarray = reference_summaries[bootstrap_idx] |
82 | | - mmd_null_samples[i] = maximum_mean_discrepancy( |
83 | | - covert_to_tensor(observed_summaries), covert_to_tensor(sampled_summaries) |
| 86 | + sampled_summaries_tensor: Tensor = convert_to_tensor(sampled_summaries, dtype="float32") |
| 87 | + mmd_null_samples[i] = convert_to_numpy( |
| 88 | + maximum_mean_discrepancy( |
| 89 | + sampled_summaries_tensor, |
| 90 | + reference_summaries_tensor, |
| 91 | + ) |
84 | 92 | ) |
85 | 93 |
|
86 | | - mmd_observed: float = float(covert_to_numpy(maximum_mean_discrepancy(observed_summaries, reference_summaries))) |
| 94 | + mmd_observed_tensor: Tensor = maximum_mean_discrepancy( |
| 95 | + observed_summaries_tensor, |
| 96 | + reference_summaries_tensor, |
| 97 | + ) |
| 98 | + |
| 99 | + mmd_observed: float = float(convert_to_numpy(mmd_observed_tensor)) |
87 | 100 |
|
88 | 101 | return mmd_observed, mmd_null_samples |
89 | 102 |
|
@@ -134,8 +147,8 @@ def compute_mmd_hypothesis_test( |
134 | 147 | mmd_null : np.ndarray |
135 | 148 | A distribution of MMD values under the null hypothesis. |
136 | 149 | """ |
137 | | - observed_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(observed_data))) |
138 | | - reference_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(reference_data))) |
| 150 | + observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(convert_to_tensor(observed_data))) |
| 151 | + reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(convert_to_tensor(reference_data))) |
139 | 152 |
|
140 | 153 | mmd_observed, mmd_null = compute_mmd_hypothesis_test_from_summaries( |
141 | 154 | observed_summaries, reference_summaries, num_null_samples=num_null_samples |
|
0 commit comments