From 58757419014e2d5e09b3fdd678f79c5db14ba1f9 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Wed, 9 Jul 2025 20:14:44 +0800 Subject: [PATCH] fix: fix a bug in flashinfer_struct --- lightllm/models/llama/flashinfer_struct.py | 2 +- lightllm/models/llama/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index cf1d10c36..cea95a203 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -81,7 +81,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.req_manager.req_to_token_indexs, self.b_req_idx, self.b_seq_len, - self.b_start_loc, + kv_starts, self.max_len_in_batch, kv_indices, ) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 6a7c28b39..b55c17afd 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -30,7 +30,7 @@ def __init__(self, model): self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] self.head_dim = model.config.get("head_dim", head_dim) - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) self.max_seq_length = model.max_seq_length self.kv_indices_buffer = [ torch.empty(