Skip to content

Commit 82b3ab4

Browse files
committed
allow tensor in DiagonalNormal dimension
1 parent c21fa4a commit 82b3ab4

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

bayesflow/distributions/diagonal_normal.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,18 @@ def __init__(
5757
self.trainable_parameters = trainable_parameters
5858
self.seed_generator = seed_generator or keras.random.SeedGenerator()
5959

60-
self.dim = None
60+
self.dims = None
6161
self._mean = None
6262
self._std = None
6363

6464
def build(self, input_shape: Shape) -> None:
6565
if self.built:
6666
return
6767

68-
self.dim = int(input_shape[-1])
68+
self.dims = input_shape[1:]
6969

70-
self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
71-
self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")
70+
self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
71+
self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")
7272

7373
if self.trainable_parameters:
7474
self._mean = self.add_weight(
@@ -91,14 +91,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9191
result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)
9292

9393
if normalize:
94-
log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
94+
log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
95+
ops.log(self._std)
96+
)
9597
result += log_normalization_constant
9698

9799
return result
98100

99101
@allow_batch_size
100102
def sample(self, batch_shape: Shape) -> Tensor:
101-
return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
103+
return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator)
102104

103105
def get_config(self):
104106
base_config = super().get_config()

0 commit comments

Comments
 (0)