Skip to content

Commit c3e945e

Browse files
authored
Merge branch 'dev' into standardize_in_approx
2 parents 43ced5b + 735969c commit c3e945e

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,18 +520,22 @@ def _sample(
520520
summary_outputs = self.summary_network(
521521
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
522522
)
523-
inference_conditions = concatenate_valid((inference_conditions, summary_outputs), axis=-1)
523+
524+
if inference_conditions is None:
525+
inference_conditions = summary_outputs
526+
else:
527+
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)
524528

525529
if inference_conditions is not None:
526-
# conditions must always have shape (batch_size, ...)
530+
# conditions must always have shape (batch_size, ..., dims)
527531
batch_size = keras.ops.shape(inference_conditions)[0]
528532
inference_conditions = keras.ops.expand_dims(inference_conditions, axis=1)
529533
inference_conditions = keras.ops.broadcast_to(
530534
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
531535
)
532-
batch_shape = (batch_size, num_samples)
536+
batch_shape = keras.ops.shape(inference_conditions)[:-1]
533537
else:
534-
batch_shape = (num_samples,)
538+
batch_shape = keras.ops.shape(inference_conditions)[1:-1]
535539

536540
return self.inference_network.sample(
537541
batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)

0 commit comments

Comments
 (0)