11from collections .abc import Sequence
22from abc import ABC , abstractmethod
3+ from typing import Union
34import keras
45from keras import ops
56import warnings
@@ -61,7 +62,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
6162 pass
6263
6364 @abstractmethod
64- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
65+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
6566 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
6667 pass
6768
@@ -141,7 +142,7 @@ class LinearNoiseSchedule(NoiseSchedule):
141142 """
142143
143144 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 ):
144- super ().__init__ (name = "linear_noise_schedule" )
145+ super ().__init__ (name = "linear_noise_schedule" , variance_type = "preserving" )
145146 self ._log_snr_min = min_log_snr
146147 self ._log_snr_max = max_log_snr
147148
@@ -154,7 +155,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
154155 # SNR = -log(exp(t^2) - 1)
155156 return - ops .log (ops .exp (ops .square (t_trunc )) - 1 )
156157
157- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
158+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
158159 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
159160 # SNR = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-snr)))
160161 return ops .sqrt (ops .log (1 + ops .exp (- log_snr_t )))
@@ -213,7 +214,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
213214 # SNR = -2 * log(tan(pi*t/2))
214215 return - 2 * ops .log (ops .tan (math .pi * t_trunc / 2 )) + 2 * self ._s_shift_cosine
215216
216- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
217+ def get_t_from_log_snr (self , log_snr_t : Union [ Tensor , float ] , training : bool ) -> Tensor :
217218 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
218219 # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
219220 return 2 / math .pi * ops .arctan (ops .exp ((2 * self ._s_shift_cosine - log_snr_t ) / 2 ))
@@ -289,7 +290,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
289290 )
290291 return snr
291292
292- def get_t_from_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
293+ def get_t_from_log_snr (self , log_snr_t : Union [ float , Tensor ] , training : bool ) -> Tensor :
293294 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
294295 if training :
295296 # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
@@ -551,8 +552,8 @@ def _forward(
551552 ) -> Tensor | tuple [Tensor , Tensor ]:
552553 integrate_kwargs = (
553554 {
554- "start_time" : self . noise_schedule . _t_min ,
555- "stop_time" : self . noise_schedule . _t_max ,
555+ "start_time" : 1.0 ,
556+ "stop_time" : 0.0 ,
556557 }
557558 | self .integrate_kwargs
558559 | kwargs
@@ -600,8 +601,8 @@ def _inverse(
600601 ) -> Tensor | tuple [Tensor , Tensor ]:
601602 integrate_kwargs = (
602603 {
603- "start_time" : self . noise_schedule . _t_max ,
604- "stop_time" : self . noise_schedule . _t_min ,
604+ "start_time" : 1.0 ,
605+ "stop_time" : 0.0 ,
605606 }
606607 | self .integrate_kwargs
607608 | kwargs
0 commit comments