@@ -18,7 +18,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
1818 the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
1919 """
2020
21- NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_chol " ,)
21+ NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor " ,)
2222 """
2323 Marks head for precision matrix Cholesky factor as an exception for adapter transformations.
2424
@@ -28,7 +28,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
2828 For more information see :py:class:`ScoringRule`.
2929 """
3030
31- TRANSFORMATION_TYPE : dict [str , str ] = {"precision_chol " : "right_side_scale_inverse" }
31+ TRANSFORMATION_TYPE : dict [str , str ] = {"precision_cholesky_factor " : "right_side_scale_inverse" }
3232 """
3333 Marks precision Cholesky factor head to handle de-standardization appropriately.
3434
@@ -41,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
4141 super ().__init__ (links = links , ** kwargs )
4242
4343 self .dim = dim
44- self .links = links or {"precision_chol " : CholeskyFactor ()}
44+ self .links = links or {"precision_cholesky_factor " : CholeskyFactor ()}
4545
4646 self .config = {"dim" : dim }
4747
@@ -51,9 +51,9 @@ def get_config(self):
5151
5252 def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
5353 self .dim = target_shape [- 1 ]
54- return dict (mean = (self .dim ,), precision_chol = (self .dim , self .dim ))
54+ return dict (mean = (self .dim ,), precision_cholesky_factor = (self .dim , self .dim ))
5555
56- def log_prob (self , x : Tensor , mean : Tensor , precision_chol : Tensor ) -> Tensor :
56+ def log_prob (self , x : Tensor , mean : Tensor , precision_cholesky_factor : Tensor ) -> Tensor :
5757 """
5858 Compute the log probability density of a multivariate Gaussian distribution.
5959
@@ -70,7 +70,7 @@ def log_prob(self, x: Tensor, mean: Tensor, precision_chol: Tensor) -> Tensor:
7070 The shape should be compatible with broadcasting against `mean`.
7171 mean : Tensor
7272 A tensor representing the mean of the multivariate Gaussian distribution.
73- precision_chol : Tensor
73+ precision_cholesky_factor : Tensor
7474 A tensor representing the lower-triangular Cholesky factor of the precision matrix
7575 of the multivariate Gaussian distribution.
7676
@@ -84,19 +84,21 @@ def log_prob(self, x: Tensor, mean: Tensor, precision_chol: Tensor) -> Tensor:
8484
8585 # Compute log determinant, exploiting Cholesky factors
8686 log_det_covariance = - 2 * keras .ops .sum (
87- keras .ops .log (keras .ops .diagonal (precision_chol , axis1 = 1 , axis2 = 2 )), axis = 1
87+ keras .ops .log (keras .ops .diagonal (precision_cholesky_factor , axis1 = 1 , axis2 = 2 )), axis = 1
8888 )
8989
9090 # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
91- # diff^T * precision_chol^T * precision_chol * diff
92- quadratic_term = keras .ops .einsum ("...i,...ji,...jk,...k->..." , diff , precision_chol , precision_chol , diff )
91+ # diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
92+ quadratic_term = keras .ops .einsum (
93+ "...i,...ji,...jk,...k->..." , diff , precision_cholesky_factor , precision_cholesky_factor , diff
94+ )
9395
9496 # Compute the log probability density
9597 log_prob = - 0.5 * (self .dim * keras .ops .log (2 * math .pi ) + log_det_covariance + quadratic_term )
9698
9799 return log_prob
98100
99- def sample (self , batch_shape : Shape , mean : Tensor , precision_chol : Tensor ) -> Tensor :
101+ def sample (self , batch_shape : Shape , mean : Tensor , precision_cholesky_factor : Tensor ) -> Tensor :
100102 """
101103 Generate samples from a multivariate Gaussian distribution.
102104
@@ -110,7 +112,7 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_chol: Tensor) -> Te
110112 mean : Tensor
111113 A tensor representing the mean of the multivariate Gaussian distribution.
112114 Must have shape (batch_size, D), where D is the dimensionality of the distribution.
113- precision_chol : Tensor
115+ precision_cholesky_factor : Tensor
114116 A tensor representing the lower-triangular Cholesky factor of the precision matrix
115117 of the multivariate Gaussian distribution.
116118 Must have shape (batch_size, D, D), where D is the dimensionality.
@@ -120,18 +122,18 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_chol: Tensor) -> Te
120122 Tensor
121123 A tensor of shape (batch_size, num_samples, D) containing the generated samples.
122124 """
123- cov_chol = keras .ops .inv (precision_chol )
125+ cov_chol = keras .ops .inv (precision_cholesky_factor )
124126 if len (batch_shape ) == 1 :
125127 batch_shape = (1 ,) + tuple (batch_shape )
126128 batch_size , num_samples = batch_shape
127129 dim = keras .ops .shape (mean )[- 1 ]
128130 if keras .ops .shape (mean ) != (batch_size , dim ):
129131 raise ValueError (f"mean must have shape (batch_size, { dim } ), but got { keras .ops .shape (mean )} " )
130132
131- if keras .ops .shape (precision_chol ) != (batch_size , dim , dim ):
133+ if keras .ops .shape (precision_cholesky_factor ) != (batch_size , dim , dim ):
132134 raise ValueError (
133135 f"covariance Cholesky factor must have shape (batch_size, { dim } , { dim } ),"
134- f"but got { keras .ops .shape (precision_chol )} "
136+ f"but got { keras .ops .shape (precision_cholesky_factor )} "
135137 )
136138
137139 # Use Cholesky decomposition to generate samples
0 commit comments