1- import keras
2- from keras .saving import (
3- serialize_keras_object as serialize ,
4- deserialize_keras_object as deserialize ,
5- )
61import pytest
72
3+ import numpy as np
4+ from scipy .stats import norm , multivariate_t
5+
6+ import keras
7+
8+ from bayesflow .distributions import DiagonalNormal , DiagonalStudentT , Mixture
9+ from bayesflow .utils .serialization import serialize , deserialize
10+
811
912def test_sample_output_shape (distribution , shape ):
1013 distribution .build (shape )
@@ -18,6 +21,42 @@ def test_log_prob_output_shape(distribution, random_samples):
1821 assert keras .ops .shape (log_prob ) == keras .ops .shape (random_samples )[:1 ]
1922
2023
24+ def test_log_prob_correctness (distribution , random_samples ):
25+ distribution .build (keras .ops .shape (random_samples ))
26+ log_prob = distribution .log_prob (random_samples )
27+ log_prob = keras .ops .convert_to_numpy (log_prob )
28+ random_samples = keras .ops .convert_to_numpy (random_samples )
29+
30+ if isinstance (distribution , DiagonalNormal ):
31+ loc = keras .ops .convert_to_numpy (distribution .mean )
32+ scale = keras .ops .convert_to_numpy (distribution .std )
33+ log_prob_scipy = norm (loc = loc , scale = scale ).logpdf (random_samples )
34+ log_prob_scipy = log_prob_scipy .sum (axis = - 1 )
35+
36+ elif isinstance (distribution , DiagonalStudentT ):
37+ loc = keras .ops .convert_to_numpy (distribution .loc )
38+ scale = keras .ops .convert_to_numpy (distribution .scale )
39+ df = distribution .df
40+ log_prob_scipy = multivariate_t (loc = loc , shape = np .diag (scale ** 2 ), df = df ).logpdf (random_samples )
41+
42+ elif isinstance (distribution , Mixture ):
43+ loc = keras .ops .convert_to_numpy (distribution .distributions [0 ].mean )
44+ scale = keras .ops .convert_to_numpy (distribution .distributions [0 ].std )
45+ log_prob_norm_scipy = norm (loc = loc , scale = scale ).logpdf (random_samples )
46+ log_prob_norm_scipy = log_prob_norm_scipy .sum (axis = - 1 )
47+
48+ loc = keras .ops .convert_to_numpy (distribution .distributions [1 ].loc )
49+ scale = keras .ops .convert_to_numpy (distribution .distributions [1 ].scale )
50+ df = distribution .distributions [1 ].df
51+ log_prob_t_scipy = multivariate_t (loc = loc , shape = np .diag (scale ** 2 ), df = df ).logpdf (random_samples )
52+ log_prob_scipy = np .log (0.5 * np .exp (log_prob_norm_scipy ) + 0.5 * np .exp (log_prob_t_scipy ))
53+
54+ else :
55+ raise RuntimeError ("distribution must be in '[DiagonalNormal, DiagonalStudentT, Mixture]'" )
56+
57+ assert np .allclose (log_prob , log_prob_scipy )
58+
59+
2160@pytest .mark .parametrize ("automatic" , [True , False ])
2261def test_build (automatic , distribution , random_samples ):
2362 assert distribution .built is False
0 commit comments