Skip to content

Commit 2fdc5c3

Browse files
committed
add rudimentary test for loss plot
1 parent 722c861 commit 2fdc5c3

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

tests/test_diagnostics/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import keras
12
import numpy as np
23
import pytest
34
from bayesflow.utils.numpy_utils import softmax
@@ -61,3 +62,19 @@ def pred_models(true_models):
6162
pred_models = np.random.normal(loc=true_models)
6263
pred_models = softmax(pred_models, axis=-1)
6364
return pred_models
65+
66+
67+
@pytest.fixture()
68+
def history():
69+
h = keras.callbacks.History()
70+
71+
step = np.linspace(0, 1, 10_000)
72+
train_loss = (1.0 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
73+
validation_loss = 0.1 + (0.75 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
74+
75+
h.history = {
76+
"loss": train_loss.tolist(),
77+
"val_loss": validation_loss.tolist(),
78+
}
79+
80+
return h

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ def test_calibration_histogram(random_estimates, random_targets):
4545
assert out.axes[0].title._text == "beta_0"
4646

4747

48+
def test_loss(history):
49+
import matplotlib.pyplot as plt
50+
51+
out = bf.diagnostics.loss(history)
52+
plt.show()
53+
assert len(out.axes) == 1
54+
assert out.axes[0].title._text == "Loss Trajectory"
55+
56+
4857
def test_recovery(random_estimates, random_targets):
4958
# basic functionality: automatic variable names
5059
out = bf.diagnostics.plots.recovery(random_estimates, random_targets)

0 commit comments

Comments
 (0)