
Commit 9e0e04b

update
1 parent bc0a787 commit 9e0e04b

2 files changed (+21 additions, -82 deletions)

2 files changed

+21
-82
lines changed

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 6 additions & 2 deletions
@@ -24,8 +24,7 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
             ]
         return cls._shared_page_table_buffer

-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+    def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
@@ -93,3 +92,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 )
             )
         return
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        self._init_flash_attention_state(model, input_ids)
+        return
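Taken together, the two hunks extract the FlashAttention metadata setup into a reusable _init_flash_attention_state helper and leave init_some_extra_state as a thin wrapper around it. A minimal sketch of the resulting shape (not lightllm source: the base class name, the placeholder body, and the comments are assumptions; only the wrapper's control flow comes from the diff):

# Sketch only. "BaseInferState" stands in for the real parent class, which this
# diff does not show; the helper body is elided to a placeholder.
class BaseInferState:
    def init_some_extra_state(self, model, input_ids):
        pass  # common per-request state built by the base class


class FlashAttentionStateInfo(BaseInferState):
    def _init_flash_attention_state(self, model, input_ids):
        # Formerly the tail of init_some_extra_state: cu_seqlens_q/k, the
        # page_table, and the optional offline_calibration_fp8kv descales.
        self.cu_seqlens_q = None  # placeholder for the real setup

    def init_some_extra_state(self, model, input_ids):
        super().init_some_extra_state(model, input_ids)      # unchanged base call
        self._init_flash_attention_state(model, input_ids)   # extracted helper
        return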

lightllm/models/qwen2_vl/flashattention_infer_struct.py

Lines changed: 15 additions & 80 deletions
@@ -2,94 +2,29 @@
 import torch
 import numpy as np
 import torch.distributed as dist
-from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
 from lightllm.common.basemodel.batch_objs import ModelInput


-class Qwen2VLFlashAttentionStateInfo(Qwen2VLInferStateInfo):
-    _shared_page_table_buffer = None
-
-    def __init__(self):
-        super().__init__()
-
-    @classmethod
-    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
-        if cls._shared_page_table_buffer is None:
-            cls._shared_page_table_buffer = [
-                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
-                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
-            ]
-        return cls._shared_page_table_buffer
-
+class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+        InferStateInfo.init_some_extra_state(self, model, input_ids)
         if self.is_prefill:
-            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
-            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            self.page_table = torch.empty(
-                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
-            )
-            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+            self.max_seq_len = self.max_kv_seq_len
+            self.q_max_seq_len = self.max_q_seq_len
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
         else:
-            # Meta information of flashattention for decoding
-            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
-            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            max_seq_len_k = self.max_kv_seq_len
-            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-                page_buffer = Qwen2VLFlashAttentionStateInfo.get_page_table_buffer(
-                    model.graph_max_batch_size, model.graph_max_len_in_batch
-                )
-                self.page_table = page_buffer[self.microbatch_index][
-                    : self.batch_size * model.graph_max_len_in_batch
-                ].reshape(self.batch_size, model.graph_max_len_in_batch)
-            else:
-                self.page_table = torch.empty(
-                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
-                )
-
-            self.page_table[:, :max_seq_len_k].copy_(
-                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k],
-                non_blocking=True,
-            )
-            self.page_table[:, max_seq_len_k:].fill_(0)
-
-        if "offline_calibration_fp8kv" in model.mode:
-            if self.is_prefill:
-                device = input_ids.device
-                # q_scale and token_batch_ids are used for per-head quantization of q; initialized outside inference to save resources
-                self.q_scale = torch.empty(
-                    (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
-                )
-                self.token_batch_ids = torch.repeat_interleave(
-                    torch.arange(self.batch_size, device=device), self.b_q_seq_len
-                )
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)

-            offline_scales = self.mem_manager.scales
-            head_num = self.mem_manager.head_num
-            # k_descale and v_descale are initialized outside inference to reduce compute
-            self.k_descale = (
-                offline_scales[:, :head_num]
-                .view(-1, 1, head_num)
-                .expand(offline_scales.shape[0], self.batch_size, head_num)
-                if offline_scales is not None
-                else torch.ones(
-                    (self.mem_manager.layer_num, self.batch_size, head_num),
-                    dtype=torch.float32,
-                    device=input_ids.device,
-                )
-            )
-            self.v_descale = (
-                offline_scales[:, head_num:]
-                .view(-1, 1, head_num)
-                .expand(offline_scales.shape[0], self.batch_size, head_num)
-                if offline_scales is not None
-                else torch.ones(
-                    (self.mem_manager.layer_num, self.batch_size, head_num),
-                    dtype=torch.float32,
-                    device=input_ids.device,
-                )
-            )
+        # init flash attention state
+        self._init_flash_attention_state(model, input_ids)
         return
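With this change the Qwen2-VL state class no longer duplicates the page-table and FP8 descale setup: it calls InferStateInfo.init_some_extra_state directly (skipping the FlashAttention parent's full initializer), builds only its own rotary sin/cos state, and then reuses the _init_flash_attention_state helper extracted in the llama file. A toy illustration of that call pattern, using hypothetical Base/Parent/Child classes rather than lightllm code:

# Hypothetical classes illustrating "bypass the direct parent's initializer but
# reuse its extracted helper", as Qwen2VLFlashAttentionStateInfo does above.
class Base:
    def init_some_extra_state(self):
        self.common = "base state"


class Parent(Base):
    def _init_flash_attention_state(self):
        self.fa_meta = "page table / kv scales"   # shared FlashAttention setup

    def init_some_extra_state(self):
        super().init_some_extra_state()           # Base first
        self._init_flash_attention_state()        # then the extracted helper


class Child(Parent):
    def init_some_extra_state(self):
        Base.init_some_extra_state(self)          # bypass Parent's wrapper
        self.rotary = "model-specific sin/cos"    # Child-only state
        self._init_flash_attention_state()        # still reuse the helper


c = Child()
c.init_some_extra_state()
print(c.common, c.fa_meta, c.rotary)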
