@@ -146,8 +146,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
146146 self ._log_snr_min = min_log_snr
147147 self ._log_snr_max = max_log_snr
148148
149- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
150- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
149+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
150+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
151151
152152 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
153153 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -205,8 +205,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
205205 self ._log_snr_max = max_log_snr
206206 self ._s_shift_cosine = s_shift_cosine
207207
208- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
209- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
208+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
209+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
210210
211211 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
212212 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -266,8 +266,8 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
266266 # convert EDM parameters to signal-to-noise ratio formulation
267267 self ._log_snr_min = - 2 * ops .log (sigma_max )
268268 self ._log_snr_max = - 2 * ops .log (sigma_min )
269- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
270- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
269+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
270+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
271271
272272 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
273273 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -478,7 +478,7 @@ def convert_prediction_to_x(
478478 if self .prediction_type == "v" :
479479 # convert v into x
480480 x = alpha_t * z - sigma_t * pred
481- elif self .prediction_type == "e " :
481+ elif self .prediction_type == "eps " :
482482 # convert noise prediction into x
483483 x = (z - sigma_t * pred ) / alpha_t
484484 elif self .prediction_type == "x" :
@@ -552,8 +552,8 @@ def _forward(
552552 ) -> Tensor | tuple [Tensor , Tensor ]:
553553 integrate_kwargs = (
554554 {
555- "start_time" : 1 .0 ,
556- "stop_time" : 0 .0 ,
555+ "start_time" : 0 .0 ,
556+ "stop_time" : 1 .0 ,
557557 }
558558 | self .integrate_kwargs
559559 | kwargs
@@ -601,8 +601,8 @@ def _inverse(
601601 ) -> Tensor | tuple [Tensor , Tensor ]:
602602 integrate_kwargs = (
603603 {
604- "start_time" : 0 .0 ,
605- "stop_time" : 1 .0 ,
604+ "start_time" : 1 .0 ,
605+ "stop_time" : 0 .0 ,
606606 }
607607 | self .integrate_kwargs
608608 | kwargs
0 commit comments