-
Couldn't load subscription status.
- Fork 6.5k
Check correct model type is passed to from_pretrained
#10189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
2fd4508
78c6e68
185a78f
679c18c
6aad7a7
c1db3bd
50b740a
b8fa81a
44f24a4
5af8c7f
baea141
c5e1e2d
dba12b6
99b0f92
3a43c8a
803e33f
c81415b
13a824e
5679067
24d79a3
3f841d5
f18687f
87f8f03
56ac8b4
87dcf54
e193563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1802,6 +1802,17 @@ def test_pipe_same_device_id_offload(self): | |||||||||||||||||||||||||||||||||
| sd.maybe_free_model_hooks() | ||||||||||||||||||||||||||||||||||
| assert sd._offload_gpu_id == 5 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def test_wrong_model(self): | ||||||||||||||||||||||||||||||||||
| tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") | ||||||||||||||||||||||||||||||||||
| with self.assertRaises(ValueError) as error_context: | ||||||||||||||||||||||||||||||||||
| _ = StableDiffusionPipeline.from_pretrained( | ||||||||||||||||||||||||||||||||||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| assert "Expected" in str(error_context.exception) | ||||||||||||||||||||||||||||||||||
| assert "text_encoder" in str(error_context.exception) | ||||||||||||||||||||||||||||||||||
| assert f"{tokenizer.__class__.__name}" in str(error_context.exception) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
| class KarrasDiffusionSchedulers(Enum): | |
| DDIMScheduler = 1 | |
| DDPMScheduler = 2 | |
| PNDMScheduler = 3 | |
| LMSDiscreteScheduler = 4 | |
| EulerDiscreteScheduler = 5 | |
| HeunDiscreteScheduler = 6 | |
| EulerAncestralDiscreteScheduler = 7 | |
| DPMSolverMultistepScheduler = 8 | |
| DPMSolverSinglestepScheduler = 9 | |
| KDPM2DiscreteScheduler = 10 | |
| KDPM2AncestralDiscreteScheduler = 11 | |
| DEISMultistepScheduler = 12 | |
| UniPCMultistepScheduler = 13 | |
| DPMSolverSDEScheduler = 14 | |
| EDMEulerScheduler = 15 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's cool. But I don't see the flow matching schedulers here. So, if I do assign a text encoder to scheduler in an RF pipeline (FluxPipeline, for example), would it still work as expected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that also works, for pipelines like Flux we're getting the type
<class 'diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler'>For SD etc we get the enum
<enum 'KarrasDiffusionSchedulers'>[<KarrasDiffusionSchedulers.DDIMScheduler: 1>, <KarrasDiffusionSchedulers.DDPMScheduler: 2>, <KarrasDiffusionSchedulers.PNDMScheduler: 3>, <KarrasDiffusionSchedulers.LMSDiscreteScheduler: 4>, <KarrasDiffusionSchedulers.EulerDiscreteScheduler: 5>, <KarrasDiffusionSchedulers.HeunDiscreteScheduler: 6>, <KarrasDiffusionSchedulers.EulerAncestralDiscreteScheduler: 7>, <KarrasDiffusionSchedulers.DPMSolverMultistepScheduler: 8>, <KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler: 9>, <KarrasDiffusionSchedulers.KDPM2DiscreteScheduler: 10>, <KarrasDiffusionSchedulers.KDPM2AncestralDiscreteScheduler: 11>, <KarrasDiffusionSchedulers.DEISMultistepScheduler: 12>, <KarrasDiffusionSchedulers.UniPCMultistepScheduler: 13>, <KarrasDiffusionSchedulers.DPMSolverSDEScheduler: 14>, <KarrasDiffusionSchedulers.EDMEulerScheduler: 15>]So we apply the same processing (str, split, strip applies for type case) to get a list of scheduler
['FlowMatchEulerDiscreteScheduler']['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler', 'LMSDiscreteScheduler', 'EulerDiscreteScheduler', 'HeunDiscreteScheduler', 'EulerAncestralDiscreteScheduler', 'DPMSolverMultistepScheduler', 'DPMSolverSinglestepScheduler', 'KDPM2DiscreteScheduler', 'KDPM2AncestralDiscreteScheduler', 'DEISMultistepScheduler', 'UniPCMultistepScheduler', 'DPMSolverSDEScheduler', 'EDMEulerScheduler']If it's not a scheduler it will raise or if it's the wrong type of scheduler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining! Works for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now also support Union, context is failed test test_load_connected_checkpoint_with_passed_obj for KandinskyV22CombinedPipeline, we also change scheduler type to Union[DDPMScheduler, UnCLIPScheduler], the test is actually for passing obj to submodels, but changing the scheduler is how that test works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests for wrong scheduler are added.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we already have the types extracted in
expected_typescan't we fetch them using the key and then check if the passed object is an instance of the type? If the expected type is an enum then we can check if the passed obj class name exists in the keys?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#10189 (comment)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might be better to make this check more agnostic to the component names.
We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.
Enforcing scheduler types might be a breaking change cc: @yiyixuxu . e.g. Using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this it would break. It would be good to enforce on the pipelines with Flow based schedulers though? (perhaps via a new Enum)
I would try something like:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this for scheduler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't need a value error here, a warning is enough, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A warning should be sufficient, it's mainly for the situation here #10093 (comment) where the wrong text encoder is given because the resulting error is uninformative.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's do a warning then:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just chiming here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).
Here
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290
if there's an unexpected module passed we raise a value error. I think the check is almost along similar lines -- users are passing assigning components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors but if we raise an error, users would know what to do to fix the in correct behaviour.