|
1 | 1 | import math |
2 | 2 |
|
3 | 3 | import keras |
| 4 | +from keras.saving import register_keras_serializable as serializable |
4 | 5 |
|
5 | 6 | from bayesflow.types import Shape, Tensor |
6 | 7 | from bayesflow.links import PositiveSemiDefinite |
|
9 | 10 | from .parametric_distribution_score import ParametricDistributionScore |
10 | 11 |
|
11 | 12 |
|
| 13 | +@serializable(package="bayesflow.scores") |
12 | 14 | class MultivariateNormalScore(ParametricDistributionScore): |
13 | 15 | r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = \log( \mathcal N (\theta; \mu, \Sigma))` |
14 | 16 |
|
@@ -96,9 +98,14 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor |
96 | 98 | A tensor of shape (batch_size, num_samples, D) containing the generated samples. |
97 | 99 | """ |
98 | 100 | batch_size, num_samples = batch_shape |
99 | | - dim = mean.shape[-1] |
100 | | - assert mean.shape == (batch_size, dim), "mean must have shape (batch_size, D)" |
101 | | - assert covariance.shape == (batch_size, dim, dim), "covariance must have shape (batch_size, D, D)" |
| 101 | + dim = keras.ops.shape(mean)[-1] |
| 102 | + if keras.ops.shape(mean) != (batch_size, dim): |
| 103 | + raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}") |
| 104 | + |
| 105 | + if keras.ops.shape(covariance) != (batch_size, dim, dim): |
| 106 | + raise ValueError( |
| 107 | + f"covariance must have shape (batch_size, {dim}, {dim}), but got {keras.ops.shape(covariance)}" |
| 108 | + ) |
102 | 109 |
|
103 | 110 | # Use Cholesky decomposition to generate samples |
104 | 111 | cholesky_factor = keras.ops.cholesky(covariance) |
|
0 commit comments