Skip to content

Commit 9fb0e1b

Browse files
committed
rename precision_chol to precision_cholesky_factor
to improve clarity.
1 parent 2d9c967 commit 9fb0e1b

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

bayesflow/scores/multivariate_normal_score.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
1818
the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
1919
"""
2020

21-
NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_chol",)
21+
NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor",)
2222
"""
2323
Marks head for precision matrix Cholesky factor as an exception for adapter transformations.
2424
@@ -28,7 +28,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
2828
For more information see :py:class:`ScoringRule`.
2929
"""
3030

31-
TRANSFORMATION_TYPE: dict[str, str] = {"precision_chol": "right_side_scale_inverse"}
31+
TRANSFORMATION_TYPE: dict[str, str] = {"precision_cholesky_factor": "right_side_scale_inverse"}
3232
"""
3333
Marks precision Cholesky factor head to handle de-standardization appropriately.
3434
@@ -41,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
4141
super().__init__(links=links, **kwargs)
4242

4343
self.dim = dim
44-
self.links = links or {"precision_chol": CholeskyFactor()}
44+
self.links = links or {"precision_cholesky_factor": CholeskyFactor()}
4545

4646
self.config = {"dim": dim}
4747

@@ -51,9 +51,9 @@ def get_config(self):
5151

5252
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
5353
self.dim = target_shape[-1]
54-
return dict(mean=(self.dim,), precision_chol=(self.dim, self.dim))
54+
return dict(mean=(self.dim,), precision_cholesky_factor=(self.dim, self.dim))
5555

56-
def log_prob(self, x: Tensor, mean: Tensor, precision_chol: Tensor) -> Tensor:
56+
def log_prob(self, x: Tensor, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
5757
"""
5858
Compute the log probability density of a multivariate Gaussian distribution.
5959
@@ -70,7 +70,7 @@ def log_prob(self, x: Tensor, mean: Tensor, precision_chol: Tensor) -> Tensor:
7070
The shape should be compatible with broadcasting against `mean`.
7171
mean : Tensor
7272
A tensor representing the mean of the multivariate Gaussian distribution.
73-
precision_chol : Tensor
73+
precision_cholesky_factor : Tensor
7474
A tensor representing the lower-triangular Cholesky factor of the precision matrix
7575
of the multivariate Gaussian distribution.
7676
@@ -84,19 +84,21 @@ def log_prob(self, x: Tensor, mean: Tensor, precision_chol: Tensor) -> Tensor:
8484

8585
# Compute log determinant, exploiting Cholesky factors
8686
log_det_covariance = -2 * keras.ops.sum(
87-
keras.ops.log(keras.ops.diagonal(precision_chol, axis1=1, axis2=2)), axis=1
87+
keras.ops.log(keras.ops.diagonal(precision_cholesky_factor, axis1=1, axis2=2)), axis=1
8888
)
8989

9090
# Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
91-
# diff^T * precision_chol^T * precision_chol * diff
92-
quadratic_term = keras.ops.einsum("...i,...ji,...jk,...k->...", diff, precision_chol, precision_chol, diff)
91+
# diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
92+
quadratic_term = keras.ops.einsum(
93+
"...i,...ji,...jk,...k->...", diff, precision_cholesky_factor, precision_cholesky_factor, diff
94+
)
9395

9496
# Compute the log probability density
9597
log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)
9698

9799
return log_prob
98100

99-
def sample(self, batch_shape: Shape, mean: Tensor, precision_chol: Tensor) -> Tensor:
101+
def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
100102
"""
101103
Generate samples from a multivariate Gaussian distribution.
102104
@@ -110,7 +112,7 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_chol: Tensor) -> Te
110112
mean : Tensor
111113
A tensor representing the mean of the multivariate Gaussian distribution.
112114
Must have shape (batch_size, D), where D is the dimensionality of the distribution.
113-
precision_chol : Tensor
115+
precision_cholesky_factor : Tensor
114116
A tensor representing the lower-triangular Cholesky factor of the precision matrix
115117
of the multivariate Gaussian distribution.
116118
Must have shape (batch_size, D, D), where D is the dimensionality.
@@ -120,18 +122,18 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_chol: Tensor) -> Te
120122
Tensor
121123
A tensor of shape (batch_size, num_samples, D) containing the generated samples.
122124
"""
123-
cov_chol = keras.ops.inv(precision_chol)
125+
cov_chol = keras.ops.inv(precision_cholesky_factor)
124126
if len(batch_shape) == 1:
125127
batch_shape = (1,) + tuple(batch_shape)
126128
batch_size, num_samples = batch_shape
127129
dim = keras.ops.shape(mean)[-1]
128130
if keras.ops.shape(mean) != (batch_size, dim):
129131
raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")
130132

131-
if keras.ops.shape(precision_chol) != (batch_size, dim, dim):
133+
if keras.ops.shape(precision_cholesky_factor) != (batch_size, dim, dim):
132134
raise ValueError(
133135
f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
134-
f"but got {keras.ops.shape(precision_chol)}"
136+
f"but got {keras.ops.shape(precision_cholesky_factor)}"
135137
)
136138

137139
# Use Cholesky decomposition to generate samples

0 commit comments

Comments
 (0)