Fix flux controlnet mode to take into account batch size #9406
@@ -748,9 +748,11 @@ def __call__(

```python
        )

        # set control mode
        orig_mode_type = type(control_mode)
        if isinstance(control_mode, list):
            control_mode = [control_mode]
        if len(control_mode) > 1:
            raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list contain 1 `int`")
```
I think you meant to do the opposite here? i.e. `if not isinstance(control_mode, list)`, then convert `control_mode` into a singleton list?
oops! yes!
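For reference, a minimal sketch of the normalization the comment above asks for, i.e. with the `isinstance` check inverted. The value assigned to `control_mode` is hypothetical; this is an illustration, not the committed change:

```python
control_mode = 2  # hypothetical: a user passing a bare int

# Wrap a bare int into a singleton list; a list passed by the user is left as-is.
if not isinstance(control_mode, list):
    control_mode = [control_mode]
if len(control_mode) > 1:
    raise ValueError(
        "For `FluxControlNet`, `control_mode` should be an `int` or a list containing one `int`"
    )
```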
Suggested change:

```diff
-        control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1, 1)
-        if orig_mode_type == int:
-            control_mode = control_mode.repeat(control_image.shape[0], 1)
+        control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
+        control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
```
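To illustrate why the `view(-1, 1).expand(batch, 1)` form covers the int case without needing the `orig_mode_type` branch, here is a small standalone PyTorch example (the batch size and mode value are made up):

```python
import torch

batch_size = 4                                       # hypothetical batch size
control_mode = torch.tensor([2], dtype=torch.long)   # a single mode wrapped in a list

# view(-1, 1) gives shape (1, 1); expand broadcasts it to (batch_size, 1)
# without copying data, matching what repeat() did in the int-only branch.
control_mode = control_mode.view(-1, 1).expand(batch_size, 1)
print(control_mode.shape)  # torch.Size([4, 1])
print(control_mode)        # tensor([[2], [2], [2], [2]])
```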
I think we need to make sure each `control_mode` is expanded to `batch_size` too for multi-controlnet
Suggested change:

```diff
             control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
-            control_mode = control_mode.view(-1, 1)
+            control_mode = control_mode.view(-1, 1).expand(control_images[0].shape[0], 1)
         else:
             raise ValueError("For multi-controlnet, control_mode should be a list")
```
Ok I made this change as well. (Will commit in a sec)
However, unlike the regular controlnet if block I do not explicitly allow for control_mode to be an int here, i.e. it won't automagically be converted to a singleton list. I think maybe in the "multi" case it's best to be explicit -- and remind the user -- to think in terms of it being a list, even if the multi controlnet object just has one controlnet contained inside it.
If we only accept an `int` for single controlnet, maybe raise a ValueError here
yes good idea
There does already appear to be a check for it here inside the controlnet class itself:
diffusers/src/diffusers/models/controlnet_flux.py, lines 289 to 290 (at b52119a)