@@ -252,7 +252,9 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
252252 use_karras_sigmas = config .pop ("use_karras_sigmas" , None )
253253 use_exponential_sigmas = config .pop ("use_exponential_sigmas" , None )
254254 use_beta_sigmas = config .pop ("use_beta_sigmas" , None )
255+ use_flow_sigmas = config .pop ("use_flow_sigmas" , None )
255256 prediction_type = config .pop ("prediction_type" , None )
257+ schedule_config = {}
256258 if use_karras_sigmas :
257259 sigma_schedule_config = {"class_name" : "KarrasSigmas" }
258260 elif use_exponential_sigmas :
@@ -263,16 +265,32 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
263265 sigma_schedule_config = {}
264266 if "beta_schedule" in config :
265267 config .update ({"class_name" : "BetaSchedule" })
266- elif "shift" in config :
268+ from .schedulers .schedules .beta_schedule import BetaSchedule
269+
270+ expected_kwargs = list (inspect .signature (BetaSchedule .__init__ ).parameters )[1 :- 1 ]
271+ for expected_kwarg in expected_kwargs :
272+ if expected_kwarg in config :
273+ schedule_config [expected_kwarg ] = config .pop (expected_kwarg )
274+ elif "shift" in config or use_flow_sigmas :
267275 config .update ({"class_name" : "FlowMatchSchedule" })
276+ from .schedulers .schedules .flow_schedule import FlowMatchSchedule
277+
278+ expected_kwargs = list (inspect .signature (FlowMatchSchedule .__init__ ).parameters )[1 :- 1 ]
279+ for expected_kwarg in expected_kwargs :
280+ if expected_kwarg in config :
281+ schedule_config [expected_kwarg ] = config .pop (expected_kwarg )
282+ if prediction_type == "flow_prediction" :
283+ prediction_type = "epsilon"
268284 if prediction_type :
269285 config .update ({"prediction_type" : prediction_type })
270- config = {
271- "_class_name" : _class_name ,
272- "_diffusers_version" : _diffusers_version ,
273- "schedule_config" : config ,
274- "sigma_schedule_config" : sigma_schedule_config ,
275- }
286+ config .update (
287+ {
288+ "_class_name" : _class_name ,
289+ "_diffusers_version" : _diffusers_version ,
290+ "schedule_config" : schedule_config ,
291+ "sigma_schedule_config" : sigma_schedule_config ,
292+ }
293+ )
276294
277295 init_dict , unused_kwargs , hidden_dict = cls .extract_init_dict (config , ** kwargs )
278296
0 commit comments