Skip to content

Commit 97652d2

Browse files
authored
Add explicit casting in apply_rope for Qwen VL (#9759)
1 parent bd1d9bc commit 97652d2

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

comfy/text_encoders/llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
128128

129129

130130
def apply_rope(xq, xk, freqs_cis):
131+
org_dtype = xq.dtype
131132
cos = freqs_cis[0]
132133
sin = freqs_cis[1]
133134
q_embed = (xq * cos) + (rotate_half(xq) * sin)
134135
k_embed = (xk * cos) + (rotate_half(xk) * sin)
135-
return q_embed, k_embed
136+
return q_embed.to(org_dtype), k_embed.to(org_dtype)
136137

137138

138139
class Attention(nn.Module):

0 commit comments

Comments
 (0)