Skip to content

Commit eebf950

Browse files
committed
Allow estimation of univariate MVN
1 parent fbc01f5 commit eebf950

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def call(
132132
if xz is None and not self.built:
133133
raise ValueError("Cannot build inference network without inference variables.")
134134
if conditions is None: # unconditional estimation uses a fixed input vector
135-
conditions = keras.ops.convert_to_tensor([[1.0]], dtype=keras.ops.dtype(xz))
135+
conditions = keras.ops.convert_to_tensor([[1.0]])
136136

137137
# pass conditions to the shared subnet
138138
output = self.subnet(conditions, training=training)

bayesflow/scores/multivariate_normal_score.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
101101
Tensor
102102
A tensor of shape (batch_size, num_samples, D) containing the generated samples.
103103
"""
104+
if len(batch_shape) == 1:
105+
batch_shape = (1,) + batch_shape
104106
batch_size, num_samples = batch_shape
105107
dim = keras.ops.shape(mean)[-1]
106108
if keras.ops.shape(mean) != (batch_size, dim):

0 commit comments

Comments
 (0)