11from collections .abc import Sequence
22from abc import ABC , abstractmethod
3+ from typing import Union
34import keras
45from keras import ops
56
@@ -60,7 +61,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
6061 pass
6162
6263 @abstractmethod
63- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
64+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
6465 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
6566 pass
6667
@@ -140,7 +141,7 @@ class LinearNoiseSchedule(NoiseSchedule):
140141 """
141142
142143 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 ):
143- super ().__init__ (name = "linear_noise_schedule" )
144+ super ().__init__ (name = "linear_noise_schedule" , variance_type = "preserving" )
144145 self ._log_snr_min = min_log_snr
145146 self ._log_snr_max = max_log_snr
146147
@@ -153,7 +154,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
153154 # log SNR: lambda(t) = -log(exp(t^2) - 1)
154155 return - ops .log (ops .exp (ops .square (t_trunc )) - 1 )
155156
156- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
157+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
157158 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
158159 # log SNR: lambda = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-lambda)))
159160 return ops .sqrt (ops .log (1 + ops .exp (- log_snr_t )))
@@ -212,7 +213,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
212213 # log SNR: lambda(t) = -2 * log(tan(pi*t/2)) + 2*s, with s the cosine shift
213214 return - 2 * ops .log (ops .tan (math .pi * t_trunc / 2 )) + 2 * self ._s_shift_cosine
214215
215- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
216+ def get_t_from_log_snr (self , log_snr_t : Union [ Tensor , float ] , training : bool ) -> Tensor :
216217 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
217218 # lambda = -2 * log(tan(pi*t/2)) + 2*s => t = 2/pi * arctan(exp((2*s - lambda)/2))
218219 return 2 / math .pi * ops .arctan (ops .exp ((2 * self ._s_shift_cosine - log_snr_t ) / 2 ))
@@ -288,7 +289,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
288289 )
289290 return snr
290291
291- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
292+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
292293 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
293294 if training :
294295 # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
@@ -543,8 +544,8 @@ def _forward(
543544 ) -> Tensor | tuple [Tensor , Tensor ]:
544545 integrate_kwargs = (
545546 {
546- "start_time" : self . noise_schedule . _t_min ,
547- "stop_time" : self . noise_schedule . _t_max ,
547+ "start_time" : 0.0 ,
548+ "stop_time" : 1.0 ,
548549 }
549550 | self .integrate_kwargs
550551 | kwargs
@@ -592,8 +593,8 @@ def _inverse(
592593 ) -> Tensor | tuple [Tensor , Tensor ]:
593594 integrate_kwargs = (
594595 {
595- "start_time" : self . noise_schedule . _t_max ,
596- "stop_time" : self . noise_schedule . _t_min ,
596+ "start_time" : 1.0 ,
597+ "stop_time" : 0.0 ,
597598 }
598599 | self .integrate_kwargs
599600 | kwargs
0 commit comments