Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,10 @@ def _sample(
inference_conditions = keras.ops.broadcast_to(
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
)
batch_shape = keras.ops.shape(inference_conditions)[:-1]
batch_shape = (
batch_size,
num_samples,
)
else:
batch_shape = (num_samples,)

Expand Down
12 changes: 6 additions & 6 deletions bayesflow/distributions/diagonal_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ def __init__(
self.trainable_parameters = trainable_parameters
self.seed_generator = seed_generator or keras.random.SeedGenerator()

self.dim = None
self.dims = None
self._mean = None
self._std = None

def build(self, input_shape: Shape) -> None:
if self.built:
return

self.dim = int(input_shape[-1])
self.dims = tuple(input_shape[1:])

self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")
self.mean = ops.cast(ops.broadcast_to(self.mean, self.dims), "float32")
self.std = ops.cast(ops.broadcast_to(self.std, self.dims), "float32")

if self.trainable_parameters:
self._mean = self.add_weight(
Expand All @@ -91,14 +91,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

if normalize:
log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
log_normalization_constant = -0.5 * sum(self.dims) * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
result += log_normalization_constant

return result

@allow_batch_size
def sample(self, batch_shape: Shape) -> Tensor:
return self._mean + self._std * keras.random.normal(shape=batch_shape + (self.dim,), seed=self.seed_generator)
return self._mean + self._std * keras.random.normal(shape=batch_shape + self.dims, seed=self.seed_generator)

def get_config(self):
base_config = super().get_config()
Expand Down
Loading