Skip to content

Commit c684bca

Browse files
committed
dims to tuple
1 parent 5c27246 commit c684bca

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

bayesflow/distributions/diagonal_normal.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def build(self, input_shape: Shape) -> None:
6565
if self.built:
6666
return
6767

68-
self.dims = input_shape[1:]
68+
self.dims = tuple(input_shape[1:])
6969

7070
self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
7171
self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")
@@ -91,9 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9191
result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)
9292

9393
if normalize:
94-
log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(
95-
ops.log(self._std)
96-
)
94+
log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
9795
result += log_normalization_constant
9896

9997
return result

0 commit comments

Comments
 (0)