Skip to content

Commit 0f7b3f5

Browse files
committed
add dtypes and type casts in compute_metrics
1 parent bd564b5 commit 0f7b3f5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,9 @@ def compute_metrics(
644644

645645
# sample training diffusion time as low discrepancy sequence to decrease variance
646646
# t_i = \mod (u_0 + i/k, 1)
647-
u0 = keras.random.uniform(shape=(1,))
648-
i = ops.arange(0, keras.ops.shape(x)[0]) # tensor of indices
649-
t = (u0 + i / keras.ops.shape(x)[0]) % 1
647+
u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x))
648+
i = ops.arange(0, keras.ops.shape(x)[0], dtype=ops.dtype(x)) # tensor of indices
649+
t = (u0 + i / ops.cast(keras.ops.shape(x)[0], dtype=ops.dtype(x))) % 1
650650
# i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
651651
# t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
652652

0 commit comments

Comments
 (0)