Merged
Changes from 2 commits
7 changes: 7 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -833,6 +833,13 @@ def load_module(name, value):

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):

@DN6 (Collaborator), Dec 12, 2024

Since we already have the types extracted in expected_types, can't we fetch them using the key and then check whether the passed object is an instance of that type? If the expected type is an enum, we can check whether the passed object's class name exists in its keys.

Contributor Author

@DN6 (Collaborator), Dec 13, 2024

I think it might be better to make this check more agnostic to the component names.

We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.

from diffusers import (
	AnimateDiffPipeline,
	UNetMotionModel,
)

unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
	"hf-internal-testing/tiny-sd-pipe", unet=unet
)
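
To make the mismatch concrete, here is a minimal sketch that uses stand-in classes rather than the real diffusers models. It assumes the allowed types can be read from a Union annotation on `__init__`; the helper `allowed_type_names` and `ToyPipeline` are hypothetical names introduced only for this illustration.

```python
from typing import Union, get_args, get_origin, get_type_hints


# Stand-in classes, for illustration only; not the real diffusers models.
class UNet2DConditionModel: ...
class UNetMotionModel: ...


class ToyPipeline:
    # Hypothetical pipeline whose `unet` accepts either class.
    def __init__(self, unet: Union[UNet2DConditionModel, UNetMotionModel]):
        self.unet = unet


def allowed_type_names(cls, param):
    """Return the class names accepted for `param` per the __init__ annotation."""
    hint = get_type_hints(cls.__init__)[param]
    # A Union annotation expands to several types; a plain annotation to one.
    types = get_args(hint) if get_origin(hint) is Union else (hint,)
    return [t.__name__ for t in types]


# A model_index.json-style entry records exactly one class name per component.
config_class_name = "UNet2DConditionModel"
passed = UNetMotionModel()

names = allowed_type_names(ToyPipeline, "unet")
strict_ok = type(passed).__name__ == config_class_name  # config-based check
signature_ok = type(passed).__name__ in names           # annotation-based check
print(names, strict_ok, signature_ok)
```

Under these assumptions the config-based comparison rejects the object while the annotation-based one accepts it, which is the situation described above.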

Enforcing scheduler types might be a breaking change, cc: @yiyixuxu. E.g. using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this would break. It would be good to enforce on the pipelines with Flow based schedulers though? (perhaps via a new Enum)

I would try something like:

        for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
            if key not in passed_class_obj:
                continue

            class_obj = passed_class_obj[key]
            _expected_class_types = []
            for expected_type in expected_types[key]:
                if isinstance(expected_type, enum.EnumMeta):
                    _expected_class_types.extend(expected_type.__members__.keys())
                else:
                    _expected_class_types.append(expected_type.__name__)

            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
            if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
                # Handle case where scheduler is still valid
                # raise if scheduler is meant to be a Flow based scheduler?
                pass  # placeholder so the empty branch is valid Python
            elif not _is_valid_type:
                raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")

@hlky (Contributor Author), Dec 13, 2024

Added this for scheduler

                _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
                _is_flow_match = "FlowMatch" in class_obj.__class__.__name__
                if _requires_flow_match and not _is_flow_match:
                    raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
                elif not _requires_flow_match and _is_flow_match:
                    raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
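
As a runnable illustration of how this name-based check behaves, the sketch below wraps the same logic in a hypothetical helper, `check_scheduler_family`, and uses empty stand-in classes instead of the real schedulers. It is only meant to show which combinations pass and which raise.

```python
# Stand-in scheduler classes, for illustration only.
class DDIMScheduler: ...
class DDPMScheduler: ...
class FlowMatchEulerDiscreteScheduler: ...


def check_scheduler_family(expected_class_names, scheduler):
    """Raise if the scheduler's FlowMatch-ness differs from what is expected.

    Hypothetical helper; the decision is based purely on class names.
    """
    name = type(scheduler).__name__
    requires_flow_match = any("FlowMatch" in n for n in expected_class_names)
    is_flow_match = "FlowMatch" in name
    if requires_flow_match and not is_flow_match:
        raise ValueError(f"Expected FlowMatch scheduler, got {name}.")
    if not requires_flow_match and is_flow_match:
        raise ValueError(f"Expected non-FlowMatch scheduler, got {name}.")


# Swapping one non-flow scheduler for another is still allowed.
check_scheduler_family(["DDPMScheduler"], DDIMScheduler())

# A non-flow scheduler on a pipeline that expects a FlowMatch one is rejected.
try:
    check_scheduler_family(["FlowMatchEulerDiscreteScheduler"], DDIMScheduler())
except ValueError as err:
    print(err)
```

With this approach, substitutions within the same scheduler family keep working, so the DDIM-with-Kandinsky style of usage mentioned earlier would not break.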

Collaborator

I think we don't need a value error here, a warning is enough, no?

Contributor Author

A warning should be sufficient. This is mainly for the situation in #10093 (comment), where the wrong text encoder is given and the resulting error is uninformative.

Collaborator

let's do a warning then:)
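
A warning-based variant could look roughly like the sketch below. It uses the standard-library `warnings` module purely so the example is self-contained; the helper name `warn_on_unexpected_type` and the stand-in encoder classes are assumptions, and the actual change may well use the library's own logger instead.

```python
import warnings


# Stand-in classes, for illustration only.
class CLIPTextModel: ...
class T5EncoderModel: ...


def warn_on_unexpected_type(key, class_obj, expected_class_names):
    """Warn, rather than raise, when a passed component has an unexpected class.

    Returns True when the class name is among the expected ones.
    """
    class_name = type(class_obj).__name__
    if class_name in expected_class_names:
        return True
    warnings.warn(
        f"Expected types for {key}: {expected_class_names}, got {class_name}.",
        stacklevel=2,
    )
    return False


# Loading continues in both cases; the second call only emits a warning.
warn_on_unexpected_type("text_encoder", CLIPTextModel(), ["CLIPTextModel"])
warn_on_unexpected_type("text_encoder", T5EncoderModel(), ["CLIPTextModel"])
```

The trade-off is the one discussed in this thread: a warning keeps existing code working, but a later failure caused by the wrong component can still be hard to trace back.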

Member

Just chiming in here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).

Here

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290

if there's an unexpected module passed we raise a value error. I think the check is almost along similar lines -- users are passing components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors, but if we raise an error, users would know what to do to fix the incorrect behaviour.

    if key not in passed_class_obj:
        continue
    class_name = passed_class_obj[key].__class__.__name__
    if class_name != expected_class_name:
        raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")

Member

Let's add a test for this too?
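
One possible shape for such a test is sketched below. It does not load a real pipeline: `validate_passed_components` is a hypothetical stand-in that mirrors the loop in this diff, and the component classes are empty placeholders, so this only illustrates the behaviour being asserted and is not the test added to the repository.

```python
# Stand-in classes, for illustration only.
class CLIPTextModel: ...
class T5EncoderModel: ...


def validate_passed_components(init_dict, passed_class_obj):
    """Hypothetical stand-in mirroring the strict class-name check in the diff."""
    for key, (_, expected_class_name) in init_dict.items():
        if key not in passed_class_obj:
            continue
        class_name = passed_class_obj[key].__class__.__name__
        if class_name != expected_class_name:
            raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")


# Same shape as the loop expects: component name -> (library, class name).
INIT_DICT = {"text_encoder": ("transformers", "CLIPTextModel")}


def test_matching_component_passes():
    validate_passed_components(INIT_DICT, {"text_encoder": CLIPTextModel()})


def test_component_not_passed_is_skipped():
    validate_passed_components(INIT_DICT, {})


def test_wrong_class_raises():
    try:
        validate_passed_components(INIT_DICT, {"text_encoder": T5EncoderModel()})
    except ValueError as err:
        assert "Expected CLIPTextModel for text_encoder, got T5EncoderModel." == str(err)
    else:
        raise AssertionError("expected a ValueError")


test_matching_component_passes()
test_component_not_passed_is_skipped()
test_wrong_class_raises()
```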


# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
    raise NotImplementedError(