
Commit bc2bda8

allow tensor in DiagonalNormal dimension (#571)
* allow tensor in DiagonalNormal dimension
1 parent d349e63 commit bc2bda8

2 files changed: 13 additions and 7 deletions

bayesflow/approximators/continuous_approximator.py

Lines changed: 7 additions & 1 deletion
@@ -535,7 +535,13 @@ def _sample(
             inference_conditions = keras.ops.broadcast_to(
                 inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
             )
-            batch_shape = keras.ops.shape(inference_conditions)[:-1]
+
+            if hasattr(self.inference_network, "base_distribution"):
+                target_shape_len = len(self.inference_network.base_distribution.dims)
+            else:
+                # point approximator has no base_distribution
+                target_shape_len = 1
+            batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len]
         else:
             batch_shape = (num_samples,)
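The new slice strips as many trailing axes from the conditions shape as the base distribution's event shape has, instead of always stripping one. A minimal standalone sketch of that shape arithmetic (the concrete shapes and the dims value below are invented for illustration, not taken from the repository):

import keras

# Hypothetical conditions tensor: (batch_size, num_samples, set_size, feature_dim)
inference_conditions = keras.ops.zeros((4, 100, 8, 16))

# A base distribution whose event covers the two trailing axes would expose
# dims such as (8, 16); its length decides how many axes are not batch axes.
dims = (8, 16)
target_shape_len = len(dims)

batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len]
# batch_shape == (4, 100); the old [:-1] slice would have kept (4, 100, 8)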

bayesflow/distributions/diagonal_normal.py

Lines changed: 6 additions & 6 deletions
@@ -57,18 +57,18 @@ def __init__(
         self.trainable_parameters = trainable_parameters
         self.seed_generator = seed_generator or keras.random.SeedGenerator()

-        self.dim = None
+        self.dims = None
         self._mean = None
         self._std = None

     def build(self, input_shape: Shape) -> None:
         if self.built:
             return

-        self.dim = int(input_shape[-1])
+        self.dims = tuple(input_shape[1:])

-        self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
-        self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")
+        self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
+        self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")

         if self.trainable_parameters:
             self._mean = self.add_weight(
@@ -91,14 +91,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

         if normalize:
-            log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
+            log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
             result += log_normalization_constant

         return result

     @allow_batch_size
     def sample(self, batch_shape: Shape) -> Tensor:
-        return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
+        return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator)

     def get_config(self):
         base_config = super().get_config()
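For reference, a usage sketch of the tuple-valued dims. It assumes DiagonalNormal is importable from bayesflow.distributions (as the file path suggests) and that the default mean/std arguments are usable; the event shape (6, 2) is invented for illustration:

from bayesflow.distributions import DiagonalNormal

dist = DiagonalNormal()       # assumption: default mean/std are acceptable
dist.build((None, 6, 2))      # dims is now the tuple (6, 2) instead of an int

samples = dist.sample((32,))  # shape (32, 6, 2), i.e. batch_shape + dims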
