Skip to content

Commit 3eaff24

Browse files
committed
fix batch_shape for point approximator
1 parent e55631d commit 3eaff24

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,12 @@ def _sample(
536536
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
537537
)
538538

539-
target_dim = self.inference_network.base_distribution.dims
540-
batch_shape = keras.ops.shape(inference_conditions)[: -len(target_dim)]
539+
if hasattr(self.inference_network, "base_distribution"):
540+
target_shape_len = len(self.inference_network.base_distribution.dims)
541+
else:
542+
# point approximator has no base_distribution
543+
target_shape_len = 1
544+
batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len]
541545
else:
542546
batch_shape = (num_samples,)
543547

0 commit comments

Comments
 (0)