 import torch
 import numpy as np
 import torch.distributed as dist
-from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
 from lightllm.common.basemodel.batch_objs import ModelInput


-class Qwen2VLFlashAttentionStateInfo(Qwen2VLInferStateInfo):
-    _shared_page_table_buffer = None
-
-    def __init__(self):
-        super().__init__()
-
-    @classmethod
-    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
-        if cls._shared_page_table_buffer is None:
-            cls._shared_page_table_buffer = [
-                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
-                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
-            ]
-        return cls._shared_page_table_buffer
-
+class Qwen2VLFlashAttentionStateInfo(FlashAttentionStateInfo):
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
-        super().init_some_extra_state(model, input_ids)
+        InferStateInfo.init_some_extra_state(self, model, input_ids)
         if self.is_prefill:
-            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
-            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            self.page_table = torch.empty(
-                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
-            )
-            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+            self.max_seq_len = self.max_kv_seq_len
+            self.q_max_seq_len = self.max_q_seq_len
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
         else:
-            # Meta information of flash attention for decoding
-            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
-            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            max_seq_len_k = self.max_kv_seq_len
-            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-                page_buffer = Qwen2VLFlashAttentionStateInfo.get_page_table_buffer(
-                    model.graph_max_batch_size, model.graph_max_len_in_batch
-                )
-                self.page_table = page_buffer[self.microbatch_index][
-                    : self.batch_size * model.graph_max_len_in_batch
-                ].reshape(self.batch_size, model.graph_max_len_in_batch)
-            else:
-                self.page_table = torch.empty(
-                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
-                )
-
-            self.page_table[:, :max_seq_len_k].copy_(
-                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k],
-                non_blocking=True,
-            )
-            self.page_table[:, max_seq_len_k:].fill_(0)
-
-        if "offline_calibration_fp8kv" in model.mode:
-            if self.is_prefill:
-                device = input_ids.device
-                # q_scale and token_batch_ids are used for per-head quantization of q; created outside the kernels to save resources
-                self.q_scale = torch.empty(
-                    (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
-                )
-                self.token_batch_ids = torch.repeat_interleave(
-                    torch.arange(self.batch_size, device=device), self.b_q_seq_len
-                )
+            position_ids = self.position_ids
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
 
-            offline_scales = self.mem_manager.scales
-            head_num = self.mem_manager.head_num
-            # to reduce inference-time computation, k_descale and v_descale are initialized outside the kernels
-            self.k_descale = (
-                offline_scales[:, :head_num]
-                .view(-1, 1, head_num)
-                .expand(offline_scales.shape[0], self.batch_size, head_num)
-                if offline_scales is not None
-                else torch.ones(
-                    (self.mem_manager.layer_num, self.batch_size, head_num),
-                    dtype=torch.float32,
-                    device=input_ids.device,
-                )
-            )
-            self.v_descale = (
-                offline_scales[:, head_num:]
-                .view(-1, 1, head_num)
-                .expand(offline_scales.shape[0], self.batch_size, head_num)
-                if offline_scales is not None
-                else torch.ones(
-                    (self.mem_manager.layer_num, self.batch_size, head_num),
-                    dtype=torch.float32,
-                    device=input_ids.device,
-                )
-            )
+        # init flash attention state
+        self._init_flash_attention_state(model, input_ids)
         return