Describe the bug
In WanRotaryPosEmbed (link), attention_head_dim is split into the t/h/w dimensions in two different ways in __init__ and forward. This causes a mismatch for some values of attention_head_dim. The same issue is also present in other models that use rotary embeddings (e.g., Skyreels_v2).
Details
Given
diffusers/src/diffusers/models/transformers/transformer_wan.py (lines 363 to 364 in 9c3b58d):

h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
If we have attention_head_dim=64, we get:

attention_head_dim = 64
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
print([t_dim, h_dim, w_dim])

which prints [24, 20, 20].
In the forward, when splitting the dimensions, we have
diffusers/src/diffusers/models/transformers/transformer_wan.py (lines 390 to 394 in 9c3b58d):

split_sizes = [
    self.attention_head_dim - 2 * (self.attention_head_dim // 3),
    self.attention_head_dim // 3,
    self.attention_head_dim // 3,
]
so if we try

attention_head_dim = 64
split_sizes = [
    attention_head_dim - 2 * (attention_head_dim // 3),
    attention_head_dim // 3,
    attention_head_dim // 3,
]
print(split_sizes)

this prints [22, 21, 21].
In most of the models, where attention_head_dim is 128, the two computations happen to match, but I was wondering whether this is a bug that should be fixed.
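For reference, here is a minimal standalone sketch (the two formulas are copied from the snippets above; the helper names and the list of head dims are just for illustration) showing for which values the two computations diverge:

# Compare the split computed in __init__ with the one computed in forward.
def init_split(attention_head_dim):
    # __init__: h/w get 2 * (dim // 6) each, t gets the remainder
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    t_dim = attention_head_dim - h_dim - w_dim
    return [t_dim, h_dim, w_dim]

def forward_split(attention_head_dim):
    # forward: h/w get dim // 3 each, t gets the remainder
    h_dim = w_dim = attention_head_dim // 3
    t_dim = attention_head_dim - 2 * h_dim
    return [t_dim, h_dim, w_dim]

for dim in (64, 96, 128, 160):
    a, b = init_split(dim), forward_split(dim)
    print(dim, a, b, "match" if a == b else "MISMATCH")

# 64 [24, 20, 20] [22, 21, 21] MISMATCH
# 96 [32, 32, 32] [32, 32, 32] match
# 128 [44, 42, 42] [44, 42, 42] match
# 160 [56, 52, 52] [54, 53, 53] MISMATCH

The two splits agree exactly when 2 * (attention_head_dim // 6) == attention_head_dim // 3, i.e., when attention_head_dim % 6 < 3, which is why the common 128 case lines up while 64 does not.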
Reproduction
NA
Logs
System Info
NA