
Commit bb67e90

breaking: parameterize MVNormalScore by inverse cholesky factor
The log_prob can be computed entirely from the inverse Cholesky factor L^{-1}. Using it directly also stabilizes the initial loss and speeds up computation. This commit additionally contains two optimizations: the precision matrix is folded into the einsum instead of being materialized, and the log determinant uses a sum of logs instead of the log of a product. Open question: is the transformation behavior "left_side_scale" still correct for the inverse matrix? As the parameterization changes, this is a breaking change. Since it resolves major stability problems for higher-dimensional problems, I think it is worth including anyway.
1 parent 47d2766 commit bb67e90
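
The two optimizations lean on standard identities for triangular factors. Below is a minimal NumPy sketch (an illustration added here, not part of the commit) that checks both; L_inv stands in for the cov_chol_inv head output, and all names are placeholders:

# Sketch of the identities behind the new log_prob (NumPy, hypothetical setup).
import numpy as np

rng = np.random.default_rng(0)
dim = 4

# Random SPD covariance and its lower Cholesky factor L, so cov = L @ L.T.
A = rng.normal(size=(dim, dim))
cov = A @ A.T + dim * np.eye(dim)
L = np.linalg.cholesky(cov)
L_inv = np.linalg.inv(L)
diff = rng.normal(size=dim)

# Sum of logs instead of log of a product:
# log det(cov) = 2 * sum(log diag(L)) = -2 * sum(log diag(L_inv)),
# since the inverse of a triangular matrix has diagonal 1 / diag(L).
assert np.allclose(np.log(np.linalg.det(cov)), -2 * np.sum(np.log(np.diag(L_inv))))

# Precision folded into the einsum: the precision matrix L_inv.T @ L_inv
# is never materialized; the quadratic term is contracted in one pass.
quad_explicit = diff @ (L_inv.T @ L_inv) @ diff
quad_fused = np.einsum("i,ji,jk,k->", diff, L_inv, L_inv, diff)
assert np.allclose(quad_explicit, quad_fused)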

File tree: 1 file changed (+15, -18 lines)


bayesflow/scores/multivariate_normal_score.py

Lines changed: 15 additions & 18 deletions
@@ -17,7 +17,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
     of the materialized value.
     """

-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol_inv",)
     """
     Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.

@@ -27,7 +27,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
     For more information see :py:class:`ScoringRule`.
     """

-    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
+    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol_inv": "left_side_scale"}
     """
     Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.

@@ -42,7 +42,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)

         self.dim = dim
-        self.links = links or {"cov_chol": CholeskyFactor()}
+        self.links = links or {"cov_chol_inv": CholeskyFactor()}

         self.config = {"dim": dim}

@@ -52,9 +52,9 @@ def get_config(self):

     def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
         self.dim = target_shape[-1]
-        return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))
+        return dict(mean=(self.dim,), cov_chol_inv=(self.dim, self.dim))

-    def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def log_prob(self, x: Tensor, mean: Tensor, cov_chol_inv: Tensor) -> Tensor:
         """
         Compute the log probability density of a multivariate Gaussian distribution.

@@ -82,25 +82,21 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         diff = x - mean

-        # Calculate precision from Cholesky factors of covariance matrix
-        cov_chol_inv = keras.ops.inv(cov_chol)
-        precision = keras.ops.matmul(
-            keras.ops.swapaxes(cov_chol_inv, -2, -1),
-            cov_chol_inv,
-        )
-
         # Compute log determinant, exploiting Cholesky factors
-        log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2
+        log_det_covariance = -2 * keras.ops.sum(
+            keras.ops.log(keras.ops.diagonal(cov_chol_inv, axis1=1, axis2=2)), axis=1
+        )

-        # Compute the quadratic term in the exponential of the multivariate Gaussian
-        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
+        # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
+        # diff^T * cov_chol_inv^T * cov_chol_inv * diff
+        quadratic_term = keras.ops.einsum("...i,...ji,...jk,...k->...", diff, cov_chol_inv, cov_chol_inv, diff)

         # Compute the log probability density
         log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)

         return log_prob

-    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol_inv: Tensor) -> Tensor:
         """
         Generate samples from a multivariate Gaussian distribution.

@@ -123,17 +119,18 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         Tensor
             A tensor of shape (batch_size, num_samples, D) containing the generated samples.
         """
+        cov_chol = keras.ops.inv(cov_chol_inv)
         if len(batch_shape) == 1:
             batch_shape = (1,) + tuple(batch_shape)
         batch_size, num_samples = batch_shape
         dim = keras.ops.shape(mean)[-1]
         if keras.ops.shape(mean) != (batch_size, dim):
             raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")

-        if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
+        if keras.ops.shape(cov_chol_inv) != (batch_size, dim, dim):
             raise ValueError(
                 f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
-                f"but got {keras.ops.shape(cov_chol)}"
+                f"but got {keras.ops.shape(cov_chol_inv)}"
             )

         # Use Cholesky decomposition to generate samples
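
On the open question above: the following toy check (again an illustration, not part of the commit, assuming de-standardization acts as a per-dimension scaling x -> S @ x with a diagonal, positive S) shows how the two factors transform, which is what any choice of transformation type has to match:

# Sketch: how the Cholesky factor and its inverse transform under x -> S @ x.
import numpy as np

rng = np.random.default_rng(1)
dim = 3
A = rng.normal(size=(dim, dim))
cov = A @ A.T + dim * np.eye(dim)
L = np.linalg.cholesky(cov)
S = np.diag(rng.uniform(0.5, 2.0, size=dim))  # positive per-dimension scales

# cov -> S @ cov @ S.T; since S @ L is lower triangular with a positive
# diagonal, uniqueness of the Cholesky factorization gives L -> S @ L,
# i.e. the covariance factor scales from the left.
assert np.allclose(np.linalg.cholesky(S @ cov @ S.T), S @ L)

# The inverse factor instead picks up inv(S) on the right:
# inv(S @ L) = inv(L) @ inv(S).
assert np.allclose(np.linalg.inv(S @ L), np.linalg.inv(L) @ np.linalg.inv(S))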
