Skip to content

Commit 4278ea0

Browse files
committed
remove mmd from two moons test
1 parent c5a5d6b commit 4278ea0

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

tests/test_two_moons/test_two_moons.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@ def test_compile(approximator, random_samples, jit_compile):
1313

1414

1515
def test_fit(approximator, train_dataset, validation_dataset, batch_size):
16-
from bayesflow.metrics import MaximumMeanDiscrepancy
17-
from bayesflow.networks import PointInferenceNetwork
18-
1916
inference_metrics = []
20-
if not isinstance(approximator.inference_network, PointInferenceNetwork):
21-
inference_metrics += [MaximumMeanDiscrepancy()]
17+
# if not isinstance(approximator.inference_network, PointInferenceNetwork):
18+
# inference_metrics += [MaximumMeanDiscrepancy()]
2219
approximator.compile(inference_metrics=inference_metrics)
2320

2421
mock_data = train_dataset[0]
@@ -41,8 +38,8 @@ def test_fit(approximator, train_dataset, validation_dataset, batch_size):
4138

4239
# test that metrics are improving
4340
metric_names = ["loss"]
44-
if not isinstance(approximator.inference_network, PointInferenceNetwork):
45-
metric_names += ["maximum_mean_discrepancy/inference_maximum_mean_discrepancy"]
41+
# if not isinstance(approximator.inference_network, PointInferenceNetwork):
42+
# metric_names += ["maximum_mean_discrepancy/inference_maximum_mean_discrepancy"]
4643
for metric in metric_names:
4744
assert metric in untrained_metrics
4845
assert metric in trained_metrics

0 commit comments

Comments
 (0)