Skip to content

Commit 735969c

Browse files
committed
Generalize sample shape to arbitrary N-D arrays
1 parent ac0461a commit 735969c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,18 +382,18 @@ def _sample(
382382
if inference_conditions is None:
383383
inference_conditions = summary_outputs
384384
else:
385-
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)
385+
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
386386

387387
if inference_conditions is not None:
388-
# conditions must always have shape (batch_size, dims)
388+
# conditions must always have shape (batch_size, ..., dims)
389389
batch_size = keras.ops.shape(inference_conditions)[0]
390390
inference_conditions = keras.ops.expand_dims(inference_conditions, axis=1)
391391
inference_conditions = keras.ops.broadcast_to(
392392
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
393393
)
394-
batch_shape = (batch_size, num_samples)
394+
batch_shape = keras.ops.shape(inference_conditions)[:-1]
395395
else:
396-
batch_shape = (num_samples,)
396+
batch_shape = keras.ops.shape(inference_conditions)[1:-1]
397397

398398
return self.inference_network.sample(
399399
batch_shape,

0 commit comments

Comments
 (0)