Skip to content

Commit 730931a

Browse files
committed
prediction_type
1 parent d9ad3f8 commit 730931a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/diffusers/configuration_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,12 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
247247

248248
# Handle old scheduler configs
249249
if "Scheduler" in cls.__name__ and "schedule_config" not in config:
250-
prediction_type = config.pop("prediction_type", None)
251250
_class_name = config.pop("_class_name", None)
252251
_diffusers_version = config.pop("_diffusers_version", None)
253252
use_karras_sigmas = config.pop("use_karras_sigmas", None)
254253
use_exponential_sigmas = config.pop("use_exponential_sigmas", None)
255254
use_beta_sigmas = config.pop("use_beta_sigmas", None)
255+
prediction_type = config.pop("prediction_type", None)
256256
if use_karras_sigmas:
257257
sigma_schedule_config = {"class_name": "KarrasSigmas"}
258258
elif use_exponential_sigmas:
@@ -265,10 +265,11 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
265265
config.update({"class_name": "BetaSchedule"})
266266
elif "shift" in config:
267267
config.update({"class_name": "FlowMatchSchedule"})
268+
if prediction_type:
269+
config.update({"prediction_type": prediction_type})
268270
config = {
269271
"_class_name": _class_name,
270272
"_diffusers_version": _diffusers_version,
271-
"prediction_type": prediction_type,
272273
"schedule_config": config,
273274
"sigma_schedule_config": sigma_schedule_config,
274275
}

0 commit comments

Comments
 (0)