import os
import torch
import numpy as np
import torch.distributed as dist
from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
from lightllm.common.basemodel.batch_objs import ModelInput


class Qwen2VLFlashAttentionStateInfo(Qwen2VLInferStateInfo):
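    """
    Infer-state container for the flash-attention path of Qwen2-VL: it prepares the
    cumulative sequence lengths, the per-request page table, and (in the
    offline-calibrated FP8 KV cache mode) the q/k/v scale tensors for each batch.
    """
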
    _shared_page_table_buffer = None

    def __init__(self):
        super().__init__()

    @classmethod
    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
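        """
        Lazily allocate two flat int32 page-table buffers (one per microbatch slot),
        each sized graph_max_batch_size * max_seq_len, and cache them on the class so
        decode steps that fit within the CUDA-graph limits can reuse the same memory.
        """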
        if cls._shared_page_table_buffer is None:
            cls._shared_page_table_buffer = [
                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
            ]
        return cls._shared_page_table_buffer

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
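        """
        Build the flash-attention metadata for the current batch: cu_seqlens_q/k and
        the page table, with separate handling for prefill and decode, plus the FP8
        KV-cache scale tensors when offline_calibration_fp8kv mode is enabled.
        """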
        super().init_some_extra_state(model, input_ids)
        if self.is_prefill:
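            # Meta information of flashattention for prefill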
            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
            self.page_table = torch.empty(
                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
            )
            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
        else:
            # Meta information of flashattention for decoding
            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
            max_seq_len_k = self.max_kv_seq_len
            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
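                # The batch fits within the CUDA-graph capture limits, so take a view of
                # the shared preallocated buffer for this microbatch instead of allocating.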
                page_buffer = Qwen2VLFlashAttentionStateInfo.get_page_table_buffer(
                    model.graph_max_batch_size, model.graph_max_len_in_batch
                )
                self.page_table = page_buffer[self.microbatch_index][
                    : self.batch_size * model.graph_max_len_in_batch
                ].reshape(self.batch_size, model.graph_max_len_in_batch)
            else:
                self.page_table = torch.empty(
                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
                )

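            # Copy the valid kv token indexes of each request and zero out the padded tail.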
            self.page_table[:, :max_seq_len_k].copy_(
                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k],
                non_blocking=True,
            )
            self.page_table[:, max_seq_len_k:].fill_(0)

        if "offline_calibration_fp8kv" in model.mode:
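            # Offline-calibrated FP8 KV cache: prepare the per-head q scale and the
            # k/v dequantization (descale) tensors.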
            if self.is_prefill:
                device = input_ids.device
                # q_scale and token_batch_ids are used for the per-head quantization of q;
                # they are initialized here, outside the inference kernels, to save resources.
                self.q_scale = torch.empty(
                    (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
                )
                self.token_batch_ids = torch.repeat_interleave(
                    torch.arange(self.batch_size, device=device), self.b_q_seq_len
                )

            offline_scales = self.mem_manager.scales
            head_num = self.mem_manager.head_num
            # To reduce the amount of computation during inference, k_descale and
            # v_descale are initialized here, outside the inference kernels.
            self.k_descale = (
                offline_scales[:, :head_num]
                .view(-1, 1, head_num)
                .expand(offline_scales.shape[0], self.batch_size, head_num)
                if offline_scales is not None
                else torch.ones(
                    (self.mem_manager.layer_num, self.batch_size, head_num),
                    dtype=torch.float32,
                    device=input_ids.device,
                )
            )
            self.v_descale = (
                offline_scales[:, head_num:]
                .view(-1, 1, head_num)
                .expand(offline_scales.shape[0], self.batch_size, head_num)
                if offline_scales is not None
                else torch.ones(
                    (self.mem_manager.layer_num, self.batch_size, head_num),
                    dtype=torch.float32,
                    device=input_ids.device,
                )
            )
        return