|
46 | 46 | from ..models.attention_processor import FusedAttnProcessor2_0 |
47 | 47 | from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin |
48 | 48 | from ..quantizers.bitsandbytes.utils import _check_bnb_status |
49 | | -from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME |
| 49 | +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin |
50 | 50 | from ..utils import ( |
51 | 51 | CONFIG_NAME, |
52 | 52 | DEPRECATED_REVISION_ARGS, |
@@ -834,28 +834,31 @@ def load_module(name, value): |
834 | 834 | return True |
835 | 835 |
|
836 | 836 | init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} |
837 | | - scheduler_types = None |
838 | | - if "scheduler" in expected_types: |
839 | | - scheduler_types = [] |
840 | | - for scheduler_type in expected_types["scheduler"]: |
841 | | - if isinstance(scheduler_type, enum.EnumMeta): |
842 | | - scheduler_types.extend(list(scheduler_type)) |
843 | | - else: |
844 | | - scheduler_types.extend([str(scheduler_type)]) |
845 | | - scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] |
846 | 837 |
|
847 | | - for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): |
| 838 | + for key in init_dict.keys(): |
848 | 839 | if key not in passed_class_obj: |
849 | 840 | continue |
850 | | - class_name = passed_class_obj[key].__class__.__name__ |
851 | | - class_name = class_name[4:] if class_name.startswith("Flax") else class_name |
852 | | - expected_class_name = ( |
853 | | - expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name |
854 | | - ) |
855 | | - if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types: |
856 | | - raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.") |
857 | | - elif key != "scheduler" and class_name != expected_class_name: |
858 | | - raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") |
| 841 | + |
| 842 | + class_obj = passed_class_obj[key] |
| 843 | + _expected_class_types = [] |
| 844 | + for expected_type in expected_types[key]: |
| 845 | + if isinstance(expected_type, enum.EnumMeta): |
| 846 | + _expected_class_types.extend(expected_type.__members__.keys()) |
| 847 | + else: |
| 848 | + _expected_class_types.append(expected_type.__name__) |
| 849 | + |
| 850 | + _is_valid_type = class_obj.__class__.__name__ in _expected_class_types |
| 851 | + if isinstance(class_obj, SchedulerMixin) and not _is_valid_type: |
| 852 | + _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) |
| 853 | + _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ |
| 854 | + if _requires_flow_match and not _is_flow_match: |
| 855 | + raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") |
| 856 | + elif not _requires_flow_match and _is_flow_match: |
| 857 | + raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") |
| 858 | + elif not _is_valid_type: |
| 859 | + raise ValueError( |
| 860 | + f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." |
| 861 | + ) |
859 | 862 |
|
860 | 863 | # Special case: safety_checker must be loaded separately when using `from_flax` |
861 | 864 | if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: |
|
0 commit comments