Skip to content

Commit c6a16e3

Browse files
committed
Improve distribution tests
1 parent 2002aad commit c6a16e3

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

tests/test_distributions/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@ def diagonal_student_t():
3232
return DiagonalStudentT(df=10)
3333

3434

35-
@pytest.fixture(params=["diagonal_normal", "diagonal_student_t"])
35+
@pytest.fixture()
36+
def mixture():
37+
from bayesflow.distributions import DiagonalNormal, DiagonalStudentT, Mixture
38+
39+
return Mixture([DiagonalNormal(), DiagonalStudentT(df=25)])
40+
41+
42+
@pytest.fixture(params=["diagonal_normal", "diagonal_student_t", "mixture"])
3643
def distribution(request):
3744
return request.getfixturevalue(request.param)
3845

tests/test_distributions/test_distributions.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
import keras
2-
from keras.saving import (
3-
serialize_keras_object as serialize,
4-
deserialize_keras_object as deserialize,
5-
)
61
import pytest
72

3+
import numpy as np
4+
from scipy.stats import norm, multivariate_t
5+
6+
import keras
7+
8+
from bayesflow.distributions import DiagonalNormal, DiagonalStudentT, Mixture
9+
from bayesflow.utils.serialization import serialize, deserialize
10+
811

912
def test_sample_output_shape(distribution, shape):
1013
distribution.build(shape)
@@ -18,6 +21,42 @@ def test_log_prob_output_shape(distribution, random_samples):
1821
assert keras.ops.shape(log_prob) == keras.ops.shape(random_samples)[:1]
1922

2023

24+
def test_log_prob_correctness(distribution, random_samples):
25+
distribution.build(keras.ops.shape(random_samples))
26+
log_prob = distribution.log_prob(random_samples)
27+
log_prob = keras.ops.convert_to_numpy(log_prob)
28+
random_samples = keras.ops.convert_to_numpy(random_samples)
29+
30+
if isinstance(distribution, DiagonalNormal):
31+
loc = keras.ops.convert_to_numpy(distribution.mean)
32+
scale = keras.ops.convert_to_numpy(distribution.std)
33+
log_prob_scipy = norm(loc=loc, scale=scale).logpdf(random_samples)
34+
log_prob_scipy = log_prob_scipy.sum(axis=-1)
35+
36+
elif isinstance(distribution, DiagonalStudentT):
37+
loc = keras.ops.convert_to_numpy(distribution.loc)
38+
scale = keras.ops.convert_to_numpy(distribution.scale)
39+
df = distribution.df
40+
log_prob_scipy = multivariate_t(loc=loc, shape=np.diag(scale**2), df=df).logpdf(random_samples)
41+
42+
elif isinstance(distribution, Mixture):
43+
loc = keras.ops.convert_to_numpy(distribution.distributions[0].mean)
44+
scale = keras.ops.convert_to_numpy(distribution.distributions[0].std)
45+
log_prob_norm_scipy = norm(loc=loc, scale=scale).logpdf(random_samples)
46+
log_prob_norm_scipy = log_prob_norm_scipy.sum(axis=-1)
47+
48+
loc = keras.ops.convert_to_numpy(distribution.distributions[1].loc)
49+
scale = keras.ops.convert_to_numpy(distribution.distributions[1].scale)
50+
df = distribution.distributions[1].df
51+
log_prob_t_scipy = multivariate_t(loc=loc, shape=np.diag(scale**2), df=df).logpdf(random_samples)
52+
log_prob_scipy = np.log(0.5 * np.exp(log_prob_norm_scipy) + 0.5 * np.exp(log_prob_t_scipy))
53+
54+
else:
55+
raise RuntimeError("distribution must be in '[DiagonalNormal, DiagonalStudentT, Mixture]'")
56+
57+
assert np.allclose(log_prob, log_prob_scipy)
58+
59+
2160
@pytest.mark.parametrize("automatic", [True, False])
2261
def test_build(automatic, distribution, random_samples):
2362
assert distribution.built is False

0 commit comments

Comments
 (0)