diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index fb2e95a56..13ba32cb9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,7 +535,13 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = keras.ops.shape(inference_conditions)[:-1] + + if hasattr(self.inference_network, "base_distribution"): + target_shape_len = len(self.inference_network.base_distribution.dims) + else: + # point approximator has no base_distribution + target_shape_len = 1 + batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len] else: batch_shape = (num_samples,) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 6b64445c7..83b3e556b 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -57,7 +57,7 @@ def __init__( self.trainable_parameters = trainable_parameters self.seed_generator = seed_generator or keras.random.SeedGenerator() - self.dim = None + self.dims = None self._mean = None self._std = None @@ -65,10 +65,10 @@ def build(self, input_shape: Shape) -> None: if self.built: return - self.dim = int(input_shape[-1]) + self.dims = tuple(input_shape[1:]) - self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32") - self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32") + self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32") + self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32") if self.trainable_parameters: self._mean = self.add_weight( @@ -91,14 +91,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - 
log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) + log_normalization_constant = -0.5 * math.prod(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) result += log_normalization_constant return result @allow_batch_size def sample(self, batch_shape: Shape) -> Tensor: - return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator) + return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator) def get_config(self): base_config = super().get_config()