Skip to content

Commit 48c2f8c

Browse files
committed
add type casting between np.ndarray and bf.types.Tensor with keras.ops
1 parent 0333c77 commit 48c2f8c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@
4343
"""
4444

4545
import numpy as np
46+
from keras.ops import covert_to_numpy, covert_to_tensor
4647

4748
from bayesflow.approximators import Approximator
4849
from bayesflow.metrics import maximum_mean_discrepancy
4950

5051

51-
# TODO: maximum_mean_discrepancy expects bayesflow.types.Tensor instead of np.ndarray as input and returns
52-
# bayesflow.types.Tensor instead of float
5352
def mmd_hypothesis_test_from_summaries(
5453
observed_summaries: np.ndarray,
5554
reference_summaries: np.ndarray,
@@ -82,14 +81,15 @@ def mmd_hypothesis_test_from_summaries(
8281
for i in range(num_null_samples):
8382
bootstrap_idx: int = np.random.randint(0, num_reference, size=num_observed)
8483
sampled_summaries: np.ndarray = reference_summaries[bootstrap_idx]
85-
mmd_null_samples[i] = maximum_mean_discrepancy(observed_summaries, sampled_summaries)
84+
mmd_null_samples[i] = maximum_mean_discrepancy(
85+
covert_to_tensor(observed_summaries), covert_to_tensor(sampled_summaries)
86+
)
8687

87-
mmd_observed: float = maximum_mean_discrepancy(observed_summaries, reference_summaries)
88+
mmd_observed: float = float(covert_to_numpy(maximum_mean_discrepancy(observed_summaries, reference_summaries)))
8889

8990
return mmd_observed, mmd_null_samples
9091

9192

92-
# TODO: approximator.summary_network takes and returns bayesflow.types.Tensor
9393
def mmd_hypothesis_test(
9494
observed_data: np.ndarray,
9595
reference_data: np.ndarray,
@@ -117,8 +117,8 @@ def mmd_hypothesis_test(
117117
mmd_null : np.ndarray
118118
A distribution of MMD values under the null hypothesis.
119119
"""
120-
observed_summaries: np.ndarray = approximator.summary_network(observed_data)
121-
reference_summaries: np.ndarray = approximator.summary_network(reference_data)
120+
observed_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(observed_data)))
121+
reference_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(reference_data)))
122122

123123
mmd_observed, mmd_null = mmd_hypothesis_test_from_summaries(
124124
observed_summaries, reference_summaries, num_null_samples=num_null_samples

0 commit comments

Comments
 (0)