-
Couldn't load subscription status.
- Fork 6.5k
Check correct model type is passed to from_pretrained
#10189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
2fd4508
78c6e68
185a78f
679c18c
6aad7a7
c1db3bd
50b740a
b8fa81a
44f24a4
5af8c7f
baea141
c5e1e2d
dba12b6
99b0f92
3a43c8a
803e33f
c81415b
13a824e
5679067
24d79a3
3f841d5
f18687f
87f8f03
56ac8b4
87dcf54
e193563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -833,6 +833,14 @@ def load_module(name, value): | |
|
|
||
| init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} | ||
|
|
||
| for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): | ||
| if key not in passed_class_obj or key == "scheduler": | ||
| continue | ||
|
||
| class_name = passed_class_obj[key].__class__.__name__ | ||
| class_name = class_name[4:] if class_name.startswith("Flax") else class_name | ||
| if class_name != expected_class_name: | ||
| raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") | ||
|
|
||
| # Special case: safety_checker must be loaded separately when using `from_flax` | ||
| if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: | ||
| raise NotImplementedError( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we already have the types extracted in
expected_typescan't we fetch them using the key and then check if the passed object is an instance of the type? If the expected type is an enum then we can check if the passed obj class name exists in the keys?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#10189 (comment)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might be better to make this check more agnostic to the component names.
We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.
Enforcing scheduler types might be a breaking change cc: @yiyixuxu . e.g. Using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this it would break. It would be good to enforce on the pipelines with Flow based schedulers though? (perhaps via a new Enum)
I would try something like:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this for scheduler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need a value error here, a warning is enough, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A warning should be sufficient, it's mainly for the situation here #10093 (comment) where the wrong text encoder is given because the resulting error is uninformative.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's do a warning then:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just chiming here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).
Here
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290
if there's an unexpected module passed we raise a value error. I think the check is almost along similar lines -- users are passing assigning components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors but if we raise an error, users would know what to do to fix the in correct behaviour.