Skip to content

Commit e840046

Browse files
committed
fix backend
1 parent 5f11724 commit e840046

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
261261
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
262262
t_trunc = self._t_min + (self._t_max - self._t_min) * t
263263
if training:
264-
snr = -icdf_gaussian(x=t_trunc, loc=-2 * self.p_mean, scale=2 * self.p_std)
264+
# SNR = -dist.icdf(t_trunc)
265+
loc = -2 * self.p_mean
266+
scale = 2 * self.p_std
267+
x = t_trunc
268+
snr = -(loc + scale * ops.erfinv(2 * x - 1) * math.sqrt(2))
265269
snr = keras.ops.clip(snr, x_min=self._log_snr_min, x_max=self._log_snr_max)
266270
else: # sampling
267271
snr = (
@@ -278,7 +282,10 @@ def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
278282
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
279283
if training:
280284
# SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
281-
t = cdf_gaussian(x=-log_snr_t, loc=-2 * self.p_mean, scale=2 * self.p_std)
285+
loc = -2 * self.p_mean
286+
scale = 2 * self.p_std
287+
x = -log_snr_t
288+
t = 0.5 * (1 + ops.erf((x - loc) / (scale * math.sqrt(2.0))))
282289
else: # sampling
283290
# SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
284291
# => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
@@ -632,8 +639,11 @@ def compute_metrics(
632639
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
633640
self.build(xz_shape, conditions_shape)
634641

635-
# sample training diffusion time
636-
t = keras.random.uniform((keras.ops.shape(x)[0],))
642+
# sample training diffusion time as low discrepancy sequence to decrease variance
643+
# t_i = \mod (u_0 + i/k, 1)
644+
u0 = keras.random.uniform(shape=(1,))
645+
i = ops.arange(0, keras.ops.shape(x)[0]) # tensor of indices
646+
t = (u0 + i / keras.ops.shape(x)[0]) % 1
637647
# i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
638648
# t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
639649

0 commit comments

Comments
 (0)