We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bddae79 · commit 97ac1a4 (Copy full SHA for 97ac1a4)
fastchat/train/llama_flash_attn_monkey_patch.py
@@ -73,7 +73,7 @@ def forward(
73
)
74
output = output.view(bsz, q_len, -1)
75
else:
76
- qvk = qkv.reshape(bsz, q_len, -1)
+ qkv = qkv.reshape(bsz, q_len, -1)
77
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
78
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
79
output_unpad = flash_attn_varlen_qkvpacked_func(
0 commit comments