Skip to content

Commit 459a0cb

Browse files
committed
override FlowMatch with pipeline from_pretrained
1 parent b34539e commit 459a0cb

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/diffusers/configuration_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ def extract_init_dict(cls, config_dict, **kwargs):
485485
# Skip keys that were not present in the original config, so default __init__ values were used
486486
used_defaults = config_dict.get("_use_default_values", [])
487487
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
488+
if (
489+
"scheduler" in config_dict
490+
and isinstance(config_dict["scheduler"], list)
491+
and config_dict["scheduler"][1].startswith("FlowMatch")
492+
):
493+
config_dict["scheduler"][1] = config_dict["scheduler"][1].replace("FlowMatch", "")
488494

489495
# 0. Copy origin config dict
490496
original_dict = dict(config_dict.items())
@@ -522,6 +528,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
522528

523529
# remove attributes from orig class that cannot be expected
524530
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
531+
if orig_cls_name.startswith("FlowMatch"):
532+
orig_cls_name = orig_cls_name.replace("FlowMatch", "")
525533
if (
526534
isinstance(orig_cls_name, str)
527535
and orig_cls_name != cls.__name__

0 commit comments

Comments
 (0)