Skip to content

Commit 9d87656

Browse files
committed
Tuple conversion in case batch_shape is a list
1 parent f1e1ba1 commit 9d87656

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

bayesflow/scores/multivariate_normal_score.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
9999
A tensor of shape (batch_size, num_samples, D) containing the generated samples.
100100
"""
101101
if len(batch_shape) == 1:
102-
batch_shape = (1,) + batch_shape
102+
batch_shape = (1,) + tuple(batch_shape)
103103
batch_size, num_samples = batch_shape
104104
dim = keras.ops.shape(mean)[-1]
105105
if keras.ops.shape(mean) != (batch_size, dim):

0 commit comments

Comments
 (0)