@@ -17,7 +17,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
1717 of the materialized value.
1818 """
1919
20- NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol " ,)
20+ NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol_inv " ,)
2121 """
2222 Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.
2323
@@ -27,7 +27,7 @@ class MultivariateNormalScore(ParametricDistributionScore):
2727 For more information see :py:class:`ScoringRule`.
2828 """
2929
30- TRANSFORMATION_TYPE : dict [str , str ] = {"cov_chol " : "left_side_scale" }
30+ TRANSFORMATION_TYPE : dict [str , str ] = {"cov_chol_inv " : "left_side_scale" }
3131 """
3232 Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.
3333
@@ -42,7 +42,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
4242 super ().__init__ (links = links , ** kwargs )
4343
4444 self .dim = dim
45- self .links = links or {"cov_chol " : CholeskyFactor ()}
45+ self .links = links or {"cov_chol_inv " : CholeskyFactor ()}
4646
4747 self .config = {"dim" : dim }
4848
@@ -52,9 +52,9 @@ def get_config(self):
5252
5353 def get_head_shapes_from_target_shape (self , target_shape : Shape ) -> dict [str , Shape ]:
5454 self .dim = target_shape [- 1 ]
55- return dict (mean = (self .dim ,), cov_chol = (self .dim , self .dim ))
55+ return dict (mean = (self .dim ,), cov_chol_inv = (self .dim , self .dim ))
5656
57- def log_prob (self , x : Tensor , mean : Tensor , cov_chol : Tensor ) -> Tensor :
57+ def log_prob (self , x : Tensor , mean : Tensor , cov_chol_inv : Tensor ) -> Tensor :
5858 """
5959 Compute the log probability density of a multivariate Gaussian distribution.
6060
@@ -82,25 +82,21 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
8282 """
8383 diff = x - mean
8484
85- # Calculate precision from Cholesky factors of covariance matrix
86- cov_chol_inv = keras .ops .inv (cov_chol )
87- precision = keras .ops .matmul (
88- keras .ops .swapaxes (cov_chol_inv , - 2 , - 1 ),
89- cov_chol_inv ,
90- )
91-
9285 # Compute log determinant, exploiting Cholesky factors
93- log_det_covariance = keras .ops .log (keras .ops .prod (keras .ops .diagonal (cov_chol , axis1 = 1 , axis2 = 2 ), axis = 1 )) * 2
86+ log_det_covariance = - 2 * keras .ops .sum (
87+ keras .ops .log (keras .ops .diagonal (cov_chol_inv , axis1 = 1 , axis2 = 2 )), axis = 1
88+ )
9489
95- # Compute the quadratic term in the exponential of the multivariate Gaussian
96- quadratic_term = keras .ops .einsum ("...i,...ij,...j->..." , diff , precision , diff )
90+ # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
91+ # diff^T * cov_chol_inv^T * cov_chol_inv * diff
92+ quadratic_term = keras .ops .einsum ("...i,...ji,...jk,...k->..." , diff , cov_chol_inv , cov_chol_inv , diff )
9793
9894 # Compute the log probability density
9995 log_prob = - 0.5 * (self .dim * keras .ops .log (2 * math .pi ) + log_det_covariance + quadratic_term )
10096
10197 return log_prob
10298
103- def sample (self , batch_shape : Shape , mean : Tensor , cov_chol : Tensor ) -> Tensor :
99+ def sample (self , batch_shape : Shape , mean : Tensor , cov_chol_inv : Tensor ) -> Tensor :
104100 """
105101 Generate samples from a multivariate Gaussian distribution.
106102
@@ -123,17 +119,18 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
123119 Tensor
124120 A tensor of shape (batch_size, num_samples, D) containing the generated samples.
125121 """
122+ cov_chol = keras .ops .inv (cov_chol_inv )
126123 if len (batch_shape ) == 1 :
127124 batch_shape = (1 ,) + tuple (batch_shape )
128125 batch_size , num_samples = batch_shape
129126 dim = keras .ops .shape (mean )[- 1 ]
130127 if keras .ops .shape (mean ) != (batch_size , dim ):
131128 raise ValueError (f"mean must have shape (batch_size, { dim } ), but got { keras .ops .shape (mean )} " )
132129
133- if keras .ops .shape (cov_chol ) != (batch_size , dim , dim ):
130+ if keras .ops .shape (cov_chol_inv ) != (batch_size , dim , dim ):
134131 raise ValueError (
135132 f"covariance Cholesky factor must have shape (batch_size, { dim } , { dim } ),"
136- f"but got { keras .ops .shape (cov_chol )} "
133+ f"but got { keras .ops .shape (cov_chol_inv )} "
137134 )
138135
139136 # Use Cholesky decomposition to generate samples
0 commit comments