Skip to content

Commit d795788

Browse files
committed
Add test and fix weird keras behavior
1 parent ee201d5 commit d795788

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

bayesflow/diagnostics/metrics/classifier_two_sample_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ def classifier_two_sample_test(
1212
estimates: np.ndarray,
1313
targets: np.ndarray,
1414
metric: str = "accuracy",
15-
patience: int = 10,
15+
patience: int = 5,
1616
max_epochs: int = 1000,
17-
batch_size: int = 64,
17+
batch_size: int = 128,
1818
return_metric_only: bool = True,
1919
validation_split: float = 0.5,
2020
standardize: bool = True,
21-
mlp_widths: Sequence = (256, 256),
21+
mlp_widths: Sequence = (64, 64),
2222
**kwargs,
2323
) -> float | Mapping[str, Any]:
2424
"""
@@ -95,6 +95,11 @@ def classifier_two_sample_test(
9595
data = np.r_[estimates, targets]
9696
labels = np.r_[np.zeros((estimates.shape[0],)), np.ones((targets.shape[0],))]
9797

98+
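# Important: shuffle before fitting. Keras takes the validation split from the end of the arrays without shuffling first, so the unshuffled [estimates, targets] stack would yield a validation set of only (or mostly) targets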
99+
shuffle_idx = np.random.permutation(data.shape[0])
100+
data = data[shuffle_idx]
101+
labels = labels[shuffle_idx]
102+
98103
# Create and train classifier with optional stopping
99104
classifier = keras.Sequential(
100105
[MLP(widths=mlp_widths, **kwargs.get("mlp_kwargs", {})), keras.layers.Dense(1, activation="sigmoid")]

tests/test_diagnostics/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ def var_names():
1010

1111
@pytest.fixture()
1212
def random_samples_a():
13-
return np.random.normal(loc=0, scale=1, size=(1000, 8))
13+
return np.random.normal(loc=0, scale=1, size=(5000, 8))
1414

1515

1616
@pytest.fixture()
1717
def random_samples_b():
18-
return np.random.normal(loc=0, scale=3, size=(1000, 8))
18+
return np.random.normal(loc=0, scale=3, size=(5000, 8))
1919

2020

2121
@pytest.fixture()

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ def test_root_mean_squared_error(random_estimates, random_targets):
5252

5353
def test_classifier_two_sample_test(random_samples_a, random_samples_b):
5454
metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, targets=random_samples_a)
55-
assert 0.6 > metric > 0.4
55+
assert 0.55 > metric > 0.45
5656

5757
metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, targets=random_samples_b)
58-
assert metric > 0.6
58+
assert metric > 0.55
5959

6060

6161
def test_expected_calibration_error(pred_models, true_models, model_names):

0 commit comments

Comments
 (0)