
Commit d8e6280

[fix]qwen2vl support fa3
1 parent d0b5fb7 commit d8e6280

File tree

1 file changed: +95, -0 lines

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import os
import torch
import numpy as np
import torch.distributed as dist
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
from lightllm.common.basemodel.batch_objs import ModelInput

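# FlashAttention-3 (fa3) inference state for Qwen2-VL: precomputes the varlen metadata
# (cu_seqlens_q / cu_seqlens_k), the paged-KV page_table, and, in the offline fp8 KV-cache
# mode, the per-head scale tensors consumed by the attention kernels.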
class Qwen2VLFlashAttentionStateInfo(Qwen2VLInferStateInfo):
    _shared_page_table_buffer = None

    def __init__(self):
        super().__init__()

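    # Two flat int32 page-table buffers are allocated lazily, one per microbatch, and reused
    # across decode steps so that CUDA-graph capture and replay see stable device pointers.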
    @classmethod
    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
        if cls._shared_page_table_buffer is None:
            cls._shared_page_table_buffer = [
                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
            ]
        return cls._shared_page_table_buffer

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        super().init_some_extra_state(model, input_ids)
        if self.is_prefill:
            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
            self.page_table = torch.empty(
                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
            )
            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
        else:
            # Meta information of flashattention for decoding
            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
            max_seq_len_k = self.max_kv_seq_len
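            # If the whole batch fits within the captured CUDA-graph limits, view a slice of the
            # shared buffer as (batch_size, graph_max_len_in_batch); otherwise allocate a fresh table.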
            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
                page_buffer = Qwen2VLFlashAttentionStateInfo.get_page_table_buffer(
                    model.graph_max_batch_size, model.graph_max_len_in_batch
                )
                self.page_table = page_buffer[self.microbatch_index][
                    : self.batch_size * model.graph_max_len_in_batch
                ].reshape(self.batch_size, model.graph_max_len_in_batch)
            else:
                self.page_table = torch.empty(
                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
                )

            self.page_table[:, :max_seq_len_k].copy_(
                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k],
                non_blocking=True,
            )
            self.page_table[:, max_seq_len_k:].fill_(0)

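        # Offline-calibrated fp8 KV cache: prepare the per-head q quantization scale buffer and
        # the per-layer, per-head k/v dequantization scales ahead of the attention kernels.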
        if "offline_calibration_fp8kv" in model.mode:
            if self.is_prefill:
                device = input_ids.device
                # q_scale and token_batch_ids are used for the per-head quantization of q;
                # they are initialized here, outside the inference kernels, to save resources.
                self.q_scale = torch.empty(
                    (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
                )
                self.token_batch_ids = torch.repeat_interleave(
                    torch.arange(self.batch_size, device=device), self.b_q_seq_len
                )

            offline_scales = self.mem_manager.scales
            head_num = self.mem_manager.head_num
            # To reduce inference-time compute, k_descale and v_descale are also initialized outside the kernels.
            self.k_descale = (
                offline_scales[:, :head_num]
                .view(-1, 1, head_num)
                .expand(offline_scales.shape[0], self.batch_size, head_num)
                if offline_scales is not None
                else torch.ones(
                    (self.mem_manager.layer_num, self.batch_size, head_num),
                    dtype=torch.float32,
                    device=input_ids.device,
                )
            )
            self.v_descale = (
                offline_scales[:, head_num:]
                .view(-1, 1, head_num)
                .expand(offline_scales.shape[0], self.batch_size, head_num)
                if offline_scales is not None
                else torch.ones(
                    (self.mem_manager.layer_num, self.batch_size, head_num),
                    dtype=torch.float32,
                    device=input_ids.device,
                )
            )
        return
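
A minimal, self-contained sketch of the decode-path page-table construction above, with made-up toy sizes (none of these values come from the commit): it gathers each active request's KV-cache slot indices from a req_to_token_indexs table into a (batch_size, max_len_in_batch) int32 page table and zero-fills the unused tail, mirroring the copy_/fill_ pattern in init_some_extra_state.

import torch

# Hypothetical toy dimensions; in lightllm these come from the request manager and the batch.
max_request_num, max_seq_len = 8, 16
batch_size, max_len_in_batch = 2, 8

# req_to_token_indexs[i, j] = KV-cache slot index of token j of request i (toy values).
req_to_token_indexs = torch.arange(max_request_num * max_seq_len, dtype=torch.int32).reshape(
    max_request_num, max_seq_len
)
b_req_idx = torch.tensor([3, 5])   # requests active in this decode batch
b_seq_len = torch.tensor([4, 6])   # current KV length of each request
max_seq_len_k = int(b_seq_len.max())

page_table = torch.empty((batch_size, max_len_in_batch), dtype=torch.int32)
page_table[:, :max_seq_len_k].copy_(req_to_token_indexs[b_req_idx, :max_seq_len_k])
page_table[:, max_seq_len_k:].fill_(0)
print(page_table)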

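The k_descale / v_descale branch above broadcasts the offline calibration scales across the batch without copying. A small sketch of just that reshape, assuming (as the slicing in the diff implies) that the scales are laid out as (layer_num, 2 * head_num) with the k-head scales in the first half and the v-head scales in the second; all sizes below are hypothetical.

import torch

layer_num, head_num, batch_size = 4, 8, 3

# Hypothetical offline calibration scales: one row per layer, k scales then v scales.
offline_scales = torch.rand(layer_num, 2 * head_num, dtype=torch.float32)

# Same view/expand as the commit: broadcast each layer's per-head scale over the batch
# dimension; expand creates a stride-0 view, so no extra memory is materialized.
k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)
v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)

print(k_descale.shape, v_descale.shape)  # torch.Size([4, 3, 8]) torch.Size([4, 3, 8])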