Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/diffusers/models/transformers/transformer_kandinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,12 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
time = time.to(dtype=torch.float32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
time = time.to(dtype=torch.float32)
origintal_dtype = time.dtype
time = time.to(dtype=torch.float32)

freqs = self.freqs.to(device=time.device, dtype=torch.float32)
args = torch.outer(time, freqs)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
time_embed = time_embed.to(dtype=original_dtype)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I cast to self.in_layer.weight.dtype instead of original_dtype is to prevent runtime crashes on backends like XPU as mentioned by @vladmandic here.
If users load the pipeline in float16, and we pass time_embed as float32, that will raise an error, won't it?
I might be wrong, correct me if so.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! I also tried to apply your suggested change and ran the exact same code attached in the PR description above on an L40S (which supports bf16), and faced below error:

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/diffusers/verify_fix.py", line 25, in <module>
    output = pipe(
             ^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py", line 731, in __call__
    pred_velocity = self.transformer(
                    ^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 647, in forward
    text_embed = text_transformer_block(text_embed, time_embed, text_rope)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 461, in forward
    self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_kandinsky.py", line 280, in forward
    return self.out_layer(self.activation(x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

Thus, I think casting to self.in_layer.weight.dtype is a safer option.
Please let me know your thoughts.

time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed

Expand Down Expand Up @@ -269,8 +271,8 @@ def __init__(self, time_dim, model_dim, num_params):
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()

@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
x = x.to(dtype=self.out_layer.weight.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

umm actually this did not look correct to me - we want to upcast it to float32, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, if we force x to float32 here, we might hit the same mismatch crash if the out_layer weights are float16/bfloat16.

return self.out_layer(self.activation(x))


Expand Down