Skip to content

Commit e55631d

Browse files
committed
fix batch_shape in sample
1 parent c684bca commit e55631d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,10 +535,9 @@ def _sample(
535535
inference_conditions = keras.ops.broadcast_to(
536536
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
537537
)
538-
batch_shape = (
539-
batch_size,
540-
num_samples,
541-
)
538+
539+
target_dim = self.inference_network.base_distribution.dims
540+
batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)]
542541
else:
543542
batch_shape = (num_samples,)
544543

0 commit comments

Comments
 (0)