@@ -520,18 +520,22 @@ def _sample(
520520 summary_outputs = self .summary_network (
521521 summary_variables , ** filter_kwargs (kwargs , self .summary_network .call )
522522 )
523- inference_conditions = concatenate_valid ((inference_conditions , summary_outputs ), axis = - 1 )
523+
524+ if inference_conditions is None :
525+ inference_conditions = summary_outputs
526+ else :
527+ inference_conditions = keras .ops .concatenate ([inference_conditions , summary_outputs ], axis = - 1 )
524528
525529 if inference_conditions is not None :
526- # conditions must always have shape (batch_size, ...)
530+ # conditions must always have shape (batch_size, ..., dims )
527531 batch_size = keras .ops .shape (inference_conditions )[0 ]
528532 inference_conditions = keras .ops .expand_dims (inference_conditions , axis = 1 )
529533 inference_conditions = keras .ops .broadcast_to (
530534 inference_conditions , (batch_size , num_samples , * keras .ops .shape (inference_conditions )[2 :])
531535 )
532- batch_shape = ( batch_size , num_samples )
536+ batch_shape = keras . ops . shape ( inference_conditions )[: - 1 ]
533537 else :
534- batch_shape = ( num_samples ,)
538+ batch_shape = keras . ops . shape ( inference_conditions )[ 1 : - 1 ]
535539
536540 return self .inference_network .sample (
537541 batch_shape , conditions = inference_conditions , ** filter_kwargs (kwargs , self .inference_network .sample )
0 commit comments