@@ -63,19 +63,19 @@ def __init__(
6363
6464 self .seed_generator = seed_generator or keras .random .SeedGenerator ()
6565
66- self .dim = None
66+ self .dims = None
6767 self ._loc = None
6868 self ._scale = None
6969
7070 def build (self , input_shape : Shape ) -> None :
7171 if self .built :
7272 return
7373
74- self .dim = int (input_shape [- 1 ])
74+ self .dims = tuple (input_shape [1 : ])
7575
7676 # convert to tensor and broadcast if necessary
77- self .loc = ops .cast (ops .broadcast_to (self .loc , ( self .dim ,) ), "float32" )
78- self .scale = ops .cast (ops .broadcast_to (self .scale , ( self .dim ,) ), "float32" )
77+ self .loc = ops .cast (ops .broadcast_to (self .loc , self .dims ), "float32" )
78+ self .scale = ops .cast (ops .broadcast_to (self .scale , self .dims ), "float32" )
7979
8080 if self .trainable_parameters :
8181 self ._loc = self .add_weight (
@@ -96,14 +96,14 @@ def build(self, input_shape: Shape) -> None:
9696
9797 def log_prob (self , samples : Tensor , * , normalize : bool = True ) -> Tensor :
9898 mahalanobis_term = ops .sum ((samples - self ._loc ) ** 2 / self ._scale ** 2 , axis = - 1 )
99- result = - 0.5 * (self .df + self .dim ) * ops .log1p (mahalanobis_term / self .df )
99+ result = - 0.5 * (self .df + sum ( self .dims ) ) * ops .log1p (mahalanobis_term / self .df )
100100
101101 if normalize :
102102 log_normalization_constant = (
103- - 0.5 * self .dim * math .log (self .df )
104- - 0.5 * self .dim * math .log (math .pi )
103+ - 0.5 * sum ( self .dims ) * math .log (self .df )
104+ - 0.5 * sum ( self .dims ) * math .log (math .pi )
105105 - math .lgamma (0.5 * self .df )
106- + math .lgamma (0.5 * (self .df + self .dim ))
106+ + math .lgamma (0.5 * (self .df + sum ( self .dims ) ))
107107 - ops .sum (keras .ops .log (self ._scale ))
108108 )
109109 result += log_normalization_constant
@@ -119,9 +119,10 @@ def sample(self, batch_shape: Shape) -> Tensor:
119119
120120 # The chi-square samples need to be repeated across self.dims,
121121 # since only one sample is created for each element of batch_shape.
122- chi2_samples = expand_tile (chi2_samples , n = self .dim , axis = - 1 )
122+ chi2_samples = expand_tile (chi2_samples , n = sum (self .dims ), axis = - 1 )
123+ chi2_samples = keras .ops .reshape (chi2_samples , batch_shape + self .dims )
123124
124- normal_samples = keras .random .normal (batch_shape + ( self .dim ,) , seed = self .seed_generator )
125+ normal_samples = keras .random .normal (batch_shape + self .dims , seed = self .seed_generator )
125126
126127 return self ._loc + self ._scale * normal_samples * ops .sqrt (self .df / chi2_samples )
127128
0 commit comments