@@ -87,6 +87,7 @@ def __init__(
8787 lower_order_final : bool = True ,
8888 euler_at_final : bool = False ,
8989 final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
90+ use_flow_sigmas : bool = False ,
9091 ):
9192 if solver_type not in ["midpoint" , "heun" ]:
9293 if solver_type in ["logrho" , "bh1" , "bh2" ]:
@@ -152,23 +153,19 @@ def precondition_noise(self, sigma):
152153 if not isinstance (sigma , torch .Tensor ):
153154 sigma = torch .tensor ([sigma ])
154155
155- return sigma .atan () / math .pi * 2
156+ if self .config .use_flow_sigmas :
157+ c_noise = sigma / (sigma + 1 )
158+ else :
159+ c_noise = sigma .atan () / math .pi * 2
160+
161+ return c_noise
156162
157163 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
158164 def precondition_outputs (self , sample , model_output , sigma ):
159- sigma_data = self .config .sigma_data
160- c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2 )
161-
162- if self .config .prediction_type == "epsilon" :
163- c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2 ) ** 0.5
164- elif self .config .prediction_type == "v_prediction" :
165- c_out = - sigma * sigma_data / (sigma ** 2 + sigma_data ** 2 ) ** 0.5
165+ if self .config .use_flow_sigmas :
166+ return self ._precondition_outputs_flow (sample , model_output , sigma )
166167 else :
167- raise ValueError (f"Prediction type { self .config .prediction_type } is not supported." )
168-
169- denoised = c_skip * sample + c_out * model_output
170-
171- return denoised
168+ return self ._precondition_outputs_edm (sample , model_output , sigma )
172169
173170 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
174171 def scale_model_input (self , sample : torch .Tensor , timestep : Union [float , torch .Tensor ]) -> torch .Tensor :
@@ -570,8 +567,42 @@ def add_noise(
570567
571568 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
572569 def _get_conditioning_c_in (self , sigma ):
573- c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
570+ if self .config .use_flow_sigmas :
571+ t = sigma / (sigma + 1 )
572+ c_in = 1.0 - t
573+ else :
574+ c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
574575 return c_in
575576
577+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_flow
578+ def _precondition_outputs_flow (self , sample , model_output , sigma ):
579+ t = sigma / (sigma + 1 )
580+ c_skip = 1.0 - t
581+
582+ if self .config .prediction_type == "epsilon" :
583+ c_out = - t
584+ elif self .config .prediction_type == "v_prediction" :
585+ c_out = t
586+ else :
587+ raise ValueError (f"Prediction type { self .config .prediction_type } is not supported." )
588+
589+ denoised = c_skip * sample + c_out * model_output
590+ return denoised
591+
592+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._precondition_outputs_edm
593+ def _precondition_outputs_edm (self , sample , model_output , sigma ):
594+ sigma_data = self .config .sigma_data
595+ c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2 )
596+
597+ if self .config .prediction_type == "epsilon" :
598+ c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2 ) ** 0.5
599+ elif self .config .prediction_type == "v_prediction" :
600+ c_out = - sigma * sigma_data / (sigma ** 2 + sigma_data ** 2 ) ** 0.5
601+ else :
602+ raise ValueError (f"Prediction type { self .config .prediction_type } is not supported." )
603+
604+ denoised = c_skip * sample + c_out * model_output
605+ return denoised
606+
576607 def __len__ (self ):
577608 return self .config .num_train_timesteps
0 commit comments