@@ -58,7 +58,7 @@ def log_prob(self, x: Tensor, mean: Tensor, precision_cholesky_factor: Tensor) -
5858 Compute the log probability density of a multivariate Gaussian distribution.
5959
6060 This function calculates the log probability density for each sample in `x` under a
61- multivariate Gaussian distribution with the given `mean` and `cov_chol `.
61+ multivariate Gaussian distribution with the given `mean` and `precision_cholesky_factor `.
6262
6363 The computation includes the determinant of the precision matrix, its inverse, and the quadratic
6464 form in the exponential term of the Gaussian density function.
@@ -122,7 +122,7 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Te
122122 Tensor
123123 A tensor of shape (batch_size, num_samples, D) containing the generated samples.
124124 """
125- cov_chol = keras .ops .inv (precision_cholesky_factor )
125+ covariance_cholesky_factor = keras .ops .inv (precision_cholesky_factor )
126126 if len (batch_shape ) == 1 :
127127 batch_shape = (1 ,) + tuple (batch_shape )
128128 batch_size , num_samples = batch_shape
@@ -139,7 +139,7 @@ def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Te
139139 # Use Cholesky decomposition to generate samples
140140 normal_samples = keras .random .normal ((* batch_shape , dim ))
141141
142- scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , cov_chol , normal_samples )
142+ scaled_normal = keras .ops .einsum ("ijk,ilk->ilj" , covariance_cholesky_factor , normal_samples )
143143 samples = mean [:, None , :] + scaled_normal
144144
145145 return samples
0 commit comments