Skip to content

Commit f2f863c

Browse files
committed
- update compute_mmd_hypothesis_test_from_summaries implementation
- unit test output shape of compute_mmd_hypothesis_test_from_summaries
1 parent 76f88b5 commit f2f863c

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
"""
2323

2424
import numpy as np
25-
from keras.ops import covert_to_numpy, covert_to_tensor
25+
from keras.ops import convert_to_numpy, convert_to_tensor
2626

2727
from bayesflow.approximators import Approximator
28-
from bayesflow.metrics import maximum_mean_discrepancy
28+
from bayesflow.metrics.functional import maximum_mean_discrepancy
29+
from bayesflow.types import Tensor
2930

3031

3132
def compute_mmd_hypothesis_test_from_summaries(
@@ -71,19 +72,31 @@ def compute_mmd_hypothesis_test_from_summaries(
7172
mmd_null : np.ndarray
7273
A distribution of MMD values under the null hypothesis.
7374
"""
75+
observed_summaries_tensor: Tensor = convert_to_tensor(observed_summaries, dtype="float32")
76+
reference_summaries_tensor: Tensor = convert_to_tensor(reference_summaries, dtype="float32")
77+
7478
num_observed: int = observed_summaries.shape[0]
7579
num_reference: int = reference_summaries.shape[0]
7680

7781
mmd_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64)
7882

7983
for i in range(num_null_samples):
80-
bootstrap_idx: int = np.random.randint(0, num_reference, size=num_observed)
84+
bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed)
8185
sampled_summaries: np.ndarray = reference_summaries[bootstrap_idx]
82-
mmd_null_samples[i] = maximum_mean_discrepancy(
83-
covert_to_tensor(observed_summaries), covert_to_tensor(sampled_summaries)
86+
sampled_summaries_tensor: Tensor = convert_to_tensor(sampled_summaries, dtype="float32")
87+
mmd_null_samples[i] = convert_to_numpy(
88+
maximum_mean_discrepancy(
89+
sampled_summaries_tensor,
90+
reference_summaries_tensor,
91+
)
8492
)
8593

86-
mmd_observed: float = float(covert_to_numpy(maximum_mean_discrepancy(observed_summaries, reference_summaries)))
94+
mmd_observed_tensor: Tensor = maximum_mean_discrepancy(
95+
observed_summaries_tensor,
96+
reference_summaries_tensor,
97+
)
98+
99+
mmd_observed: float = float(convert_to_numpy(mmd_observed_tensor))
87100

88101
return mmd_observed, mmd_null_samples
89102

@@ -134,8 +147,8 @@ def compute_mmd_hypothesis_test(
134147
mmd_null : np.ndarray
135148
A distribution of MMD values under the null hypothesis.
136149
"""
137-
observed_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(observed_data)))
138-
reference_summaries: np.ndarray = covert_to_numpy(approximator.summary_network(covert_to_tensor(reference_data)))
150+
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(convert_to_tensor(observed_data)))
151+
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(convert_to_tensor(reference_data)))
139152

140153
mmd_observed, mmd_null = compute_mmd_hypothesis_test_from_summaries(
141154
observed_summaries, reference_summaries, num_null_samples=num_null_samples

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import bayesflow as bf
1+
import numpy as np
22
import pytest
33

4+
import bayesflow as bf
5+
46

57
def num_variables(x: dict):
68
return sum(arr.shape[-1] for arr in x.values())
@@ -71,3 +73,23 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
7173

7274
with pytest.raises(Exception):
7375
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)
76+
77+
78+
# -------------------------------------------------------------------------------------------------------------------- #
79+
# Unit tests for MMD Hypothesis Test #
80+
# -------------------------------------------------------------------------------------------------------------------- #
81+
82+
83+
def test_compute_hypothesis_test_from_summaries_shapes() -> None:
84+
"""Test the compute_hypothesis_test_from_summaries output shapes."""
85+
observed_summaries = np.random.rand(10, 5)
86+
reference_summaries = np.random.rand(100, 5)
87+
num_null_samples = 50
88+
89+
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test_from_summaries(
90+
observed_summaries, reference_summaries, num_null_samples=num_null_samples
91+
)
92+
93+
assert isinstance(mmd_observed, float)
94+
assert isinstance(mmd_null, np.ndarray)
95+
assert mmd_null.shape == (num_null_samples,)

0 commit comments

Comments
 (0)