Skip to content

Commit fdac500

Browse files
committed
fix: FM, base_distribution.sample has no argument 'seed'
1 parent 9071be4 commit fdac500

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,10 @@ def compute_metrics(
191191
else:
192192
# not pre-configured, resample
193193
x1 = x
194-
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
194+
if not self.base_distribution.built:
195+
# ensure that base distribution is built
196+
self.base_distribution.build(keras.ops.shape(x1))
197+
x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
195198

196199
if self.use_optimal_transport:
197200
x1, x0, conditions = optimal_transport(

0 commit comments

Comments
 (0)