From 82b3ab4088a82b3bc951175a0e9d5f9ee971a77a Mon Sep 17 00:00:00 2001 From: arrjon Date: Sat, 6 Sep 2025 11:31:27 +0200 Subject: [PATCH 1/6] allow tensor in DiagonalNormal dimension --- bayesflow/distributions/diagonal_normal.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 6b64445c7..25a7797df 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -57,7 +57,7 @@ def __init__( self.trainable_parameters = trainable_parameters self.seed_generator = seed_generator or keras.random.SeedGenerator() - self.dim = None + self.dims = None self._mean = None self._std = None @@ -65,10 +65,10 @@ def build(self, input_shape: Shape) -> None: if self.built: return - self.dim = int(input_shape[-1]) + self.dims = input_shape[1:] - self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32") - self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32") + self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32") + self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32") if self.trainable_parameters: self._mean = self.add_weight( @@ -91,14 +91,16 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) + log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( + ops.log(self._std) + ) result += log_normalization_constant return result @allow_batch_size def sample(self, batch_shape: Shape) -> Tensor: - return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator) + return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator) def get_config(self): base_config = super().get_config() From 8fbf7374ca95826e6f7f954e237c0c8d1a955b95 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 15:04:54 +0200 Subject: [PATCH 2/6] fix sum dims --- bayesflow/distributions/diagonal_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 25a7797df..9cf068137 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -91,7 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * ops.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( + log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( ops.log(self._std) ) result += log_normalization_constant From 5c27246b3aa48f9c8ba400596200c20ada7251ae Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 15:08:16 +0200 Subject: [PATCH 3/6] fix batch_shape for sample --- bayesflow/approximators/continuous_approximator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index fb2e95a56..f27a612f0 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,7 +535,10 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = keras.ops.shape(inference_conditions)[:-1] + batch_shape = ( + batch_size, + num_samples, + ) else: batch_shape = (num_samples,) From c684bcace2add939945da94af16f3de8b0f9cdc9 Mon Sep 17 00:00:00 2001 From: arrjon Date: Sun, 7 Sep 2025 17:55:10 +0200 Subject: [PATCH 4/6] dims to tuple --- bayesflow/distributions/diagonal_normal.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py index 9cf068137..83b3e556b 100644 --- a/bayesflow/distributions/diagonal_normal.py +++ b/bayesflow/distributions/diagonal_normal.py @@ -65,7 +65,7 @@ def build(self, input_shape: Shape) -> None: if self.built: return - self.dims = input_shape[1:] + self.dims = tuple(input_shape[1:]) self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32") self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32") @@ -91,9 +91,7 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor: result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1) if normalize: - log_normalization_constant = -0.5 * np.sum(self.dims) * math.log(2.0 * math.pi) - ops.sum( - ops.log(self._std) - ) + log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std)) result += log_normalization_constant return result From e55631dcec299cecc0e1ce27ebf0cd668ab960ba Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:30:35 +0200 Subject: [PATCH 5/6] fix batch_shape in sample --- bayesflow/approximators/continuous_approximator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index f27a612f0..b60f4e4bd 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -535,10 +535,9 @@ def _sample( inference_conditions = keras.ops.broadcast_to( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - batch_shape = ( - batch_size, - num_samples, - ) + + target_dim = self.inference_network.base_distribution.dims + batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] else: batch_shape = (num_samples,) From 3eaff24a1d314435040a688559e3c85d5235f9c6 Mon Sep 17 00:00:00 2001 From: arrjon Date: Tue, 9 Sep 2025 15:35:56 +0200 Subject: [PATCH 6/6] fix batch_shape for point approximator --- bayesflow/approximators/continuous_approximator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index b60f4e4bd..13ba32cb9 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -536,8 +536,12 @@ def _sample( inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:]) ) - target_dim = self.inference_network.base_distribution.dims - batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)] + if hasattr(self.inference_network, "base_distribution"): + target_shape_len = len(self.inference_network.base_distribution.dims) + else: + # point approximator has no base_distribution + target_shape_len = 1 + batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len] else: batch_shape = (num_samples,)