Skip to content

Commit 24d79a3

Browse files
committed
update
1 parent 5679067 commit 24d79a3

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from ..models.attention_processor import FusedAttnProcessor2_0
4747
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
4848
from ..quantizers.bitsandbytes.utils import _check_bnb_status
49-
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
49+
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
5050
from ..utils import (
5151
CONFIG_NAME,
5252
DEPRECATED_REVISION_ARGS,
@@ -834,28 +834,31 @@ def load_module(name, value):
834834
return True
835835

836836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
837-
scheduler_types = None
838-
if "scheduler" in expected_types:
839-
scheduler_types = []
840-
for scheduler_type in expected_types["scheduler"]:
841-
if isinstance(scheduler_type, enum.EnumMeta):
842-
scheduler_types.extend(list(scheduler_type))
843-
else:
844-
scheduler_types.extend([str(scheduler_type)])
845-
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]
846837

847-
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
838+
for key in init_dict.keys():
848839
if key not in passed_class_obj:
849840
continue
850-
class_name = passed_class_obj[key].__class__.__name__
851-
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
852-
expected_class_name = (
853-
expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name
854-
)
855-
if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types:
856-
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
857-
elif key != "scheduler" and class_name != expected_class_name:
858-
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
841+
842+
class_obj = passed_class_obj[key]
843+
_expected_class_types = []
844+
for expected_type in expected_types[key]:
845+
if isinstance(expected_type, enum.EnumMeta):
846+
_expected_class_types.extend(expected_type.__members__.keys())
847+
else:
848+
_expected_class_types.append(expected_type.__name__)
849+
850+
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
851+
if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
852+
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
853+
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
854+
if _requires_flow_match and not _is_flow_match:
855+
raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
856+
elif not _requires_flow_match and _is_flow_match:
857+
raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
858+
elif not _is_valid_type:
859+
raise ValueError(
860+
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
861+
)
859862

860863
# Special case: safety_checker must be loaded separately when using `from_flax`
861864
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:

0 commit comments

Comments
 (0)