@@ -57,18 +57,18 @@ def __init__(
5757 self .trainable_parameters = trainable_parameters
5858 self .seed_generator = seed_generator or keras .random .SeedGenerator ()
5959
60- self .dim = None
60+ self .dims = None
6161 self ._mean = None
6262 self ._std = None
6363
6464 def build (self , input_shape : Shape ) -> None :
6565 if self .built :
6666 return
6767
68- self .dim = int ( input_shape [- 1 ])
68+ self .dims = input_shape [1 :]
6969
70- self .mean = ops .cast (ops .broadcast_to (self .mean , ( self .dim ,) ), "float32" )
71- self .std = ops .cast (ops .broadcast_to (self .std , ( self .dim ,) ), "float32" )
70+ self .mean = ops .cast (ops .broadcast_to (self .mean , self .dims ), "float32" )
71+ self .std = ops .cast (ops .broadcast_to (self .std , self .dims ), "float32" )
7272
7373 if self .trainable_parameters :
7474 self ._mean = self .add_weight (
@@ -91,14 +91,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9191 result = - 0.5 * ops .sum ((samples - self ._mean ) ** 2 / self ._std ** 2 , axis = - 1 )
9292
9393 if normalize :
94- log_normalization_constant = - 0.5 * self .dim * math .log (2.0 * math .pi ) - ops .sum (ops .log (self ._std ))
94+ log_normalization_constant = - 0.5 * ops .sum (self .dims ) * math .log (2.0 * math .pi ) - ops .sum (
95+ ops .log (self ._std )
96+ )
9597 result += log_normalization_constant
9698
9799 return result
98100
99101 @allow_batch_size
100102 def sample (self , batch_shape : Shape ) -> Tensor :
101- return self ._mean + self ._std * keras .random .normal (shape = batch_shape + ( self .dim ,) , seed = self .seed_generator )
103+ return self ._mean + self ._std * keras .random .normal (shape = batch_shape + self .dims , seed = self .seed_generator )
102104
103105 def get_config (self ):
104106 base_config = super ().get_config ()
0 commit comments