@@ -261,7 +261,11 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
261261 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
262262 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
263263 if training :
264- snr = - icdf_gaussian (x = t_trunc , loc = - 2 * self .p_mean , scale = 2 * self .p_std )
264+ # SNR = -dist.icdf(t_trunc)
265+ loc = - 2 * self .p_mean
266+ scale = 2 * self .p_std
267+ x = t_trunc
268+ snr = - (loc + scale * ops .erfinv (2 * x - 1 ) * math .sqrt (2 ))
265269 snr = keras .ops .clip (snr , x_min = self ._log_snr_min , x_max = self ._log_snr_max )
266270 else : # sampling
267271 snr = (
@@ -278,7 +282,10 @@ def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
278282 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
279283 if training :
280284 # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
281- t = cdf_gaussian (x = - log_snr_t , loc = - 2 * self .p_mean , scale = 2 * self .p_std )
285+ loc = - 2 * self .p_mean
286+ scale = 2 * self .p_std
287+ x = - log_snr_t
288+ t = 0.5 * (1 + ops .erf ((x - loc ) / (scale * math .sqrt (2.0 ))))
282289 else : # sampling
283290 # SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
284291 # => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
@@ -632,8 +639,11 @@ def compute_metrics(
632639 conditions_shape = None if conditions is None else keras .ops .shape (conditions )
633640 self .build (xz_shape , conditions_shape )
634641
635- # sample training diffusion time
636- t = keras .random .uniform ((keras .ops .shape (x )[0 ],))
642+ # sample training diffusion time as low discrepancy sequence to decrease variance
643+ # t_i = \mod (u_0 + i/k, 1)
644+ u0 = keras .random .uniform (shape = (1 ,))
645+ i = ops .arange (0 , keras .ops .shape (x )[0 ]) # tensor of indices
646+ t = (u0 + i / keras .ops .shape (x )[0 ]) % 1
637647 # i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
638648 # t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
639649
0 commit comments