Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning #12814
base: main
Conversation
Looks good to me!
Thanks for the quick fix! I didn't have time to submit a PR myself, so I really appreciate you jumping on this. 🙏
@yiyixuxu @sayakpaul
sayakpaul left a comment
Thank you! Could you also provide your testing script?
The verification script is already provided in the PR description above:

from diffusers.models.transformers import transformer_kandinsky
print("Import successful.")

This should print a UserWarning on main, but not on this branch.
yiyixuxu left a comment
thanks!
@bot /style
Style bot fixed some files and pushed the changes.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
    args = torch.outer(time, self.freqs.to(device=time.device))
    time = time.to(dtype=torch.float32)
Suggested change:
-    time = time.to(dtype=torch.float32)
+    original_dtype = time.dtype
+    time = time.to(dtype=torch.float32)
    freqs = self.freqs.to(device=time.device, dtype=torch.float32)
    args = torch.outer(time, freqs)
    time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
Suggested change:
-    time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
+    time_embed = time_embed.to(dtype=original_dtype)
The reason I cast to self.in_layer.weight.dtype instead of original_dtype is to prevent runtime crashes on backends like XPU as mentioned by @vladmandic here.
If users load the pipeline in float16, and we pass time_embed as float32, that will raise an error, won't it?
I might be wrong, correct me if so.
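For illustration (a minimal sketch, not code from this PR): outside of autocast, a float32 activation fed into a half-precision linear layer raises a dtype-mismatch error, which is the crash the cast to self.in_layer.weight.dtype is guarding against.

```python
import torch

# Illustrative only: half-precision weights, float32 input, no autocast active.
layer = torch.nn.Linear(4, 4).to(torch.float16)
x = torch.randn(2, 4, dtype=torch.float32)

try:
    layer(x)  # raises a RuntimeError about mismatched matmul dtypes
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# Casting the activation to the weight dtype avoids the crash.
out = layer(x.to(layer.weight.dtype))
print(out.dtype)  # torch.float16
```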
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
    x = x.to(dtype=self.out_layer.weight.dtype)
umm actually this did not look correct to me - we want to upcast it to float32, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, if we force x to float32 here, we might hit the same mismatch crash if the out_layer weights are float16/bfloat16.
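For context, a common middle ground (a sketch only, not the final code in this PR; project is a made-up helper) is to do the precision-sensitive math in float32 and cast back to the projection weight's dtype right before the matmul, which is essentially what the time-embedding forward above already does:

```python
import torch

def project(x: torch.Tensor, out_layer: torch.nn.Linear) -> torch.Tensor:
    # Upcast so the preceding, precision-sensitive ops run in float32.
    x = x.to(dtype=torch.float32)
    # ... float32-sensitive operations would go here ...
    # Cast back so float16/bfloat16 checkpoints do not hit a dtype mismatch.
    return out_layer(x.to(dtype=out_layer.weight.dtype))
```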
What does this PR do?
Fixes #12809
This PR fixes it by removing the hardcoded @torch.autocast(device_type="cuda", dtype=torch.float32) decorators from the Kandinsky 5 transformer and performing the float32 upcasting explicitly inside the affected forward methods.
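For background (my reading of the issue, not stated verbatim in it): the decorator constructs the autocast object at import time, and torch warns when device_type="cuda" is requested on a machine without CUDA, so merely importing the module triggers the UserWarning.

```python
import torch

# Sketch: constructing the autocast decorator is enough to emit a warning like
# "User provided device_type of 'cuda', but CUDA is not available" on a
# CUDA-less machine; no forward call is needed.
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward_like(x):
    return x
```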
Verification
I verified that the results remain stable before and after this change by generating images with a fixed seed (generator=torch.manual_seed(42)). The results are almost the same, with some minor differences.
Reproduction Script
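The original script is collapsed in this view; a minimal sketch consistent with the import check described in the review thread (the warnings capture is my addition) would be:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # On main this import emits a UserWarning from the hardcoded CUDA autocast;
    # on this branch it should import without it.
    from diffusers.models.transformers import transformer_kandinsky  # noqa: F401

print("Import successful.")
print([str(w.message) for w in caught])  # the CUDA autocast warning should be absent here
```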
Who can review?
@yiyixuxu @leffff
Anyone in the community is free to review the PR once the tests have passed.