diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py
index cf1d10c36..cea95a203 100644
--- a/lightllm/models/llama/flashinfer_struct.py
+++ b/lightllm/models/llama/flashinfer_struct.py
@@ -81,7 +81,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.req_manager.req_to_token_indexs,
                 self.b_req_idx,
                 self.b_seq_len,
-                self.b_start_loc,
+                kv_starts,
                 self.max_len_in_batch,
                 kv_indices,
             )
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 6a7c28b39..b55c17afd 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -30,7 +30,7 @@ def __init__(self, model):
         self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1)
         head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
         self.head_dim = model.config.get("head_dim", head_dim)
-        self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id())
+        self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id())
         self.max_seq_length = model.max_seq_length
         self.kv_indices_buffer = [
             torch.empty(
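
Note on the two hunks: the first replaces self.b_start_loc (per-request offsets into the flattened token batch) with kv_starts in the kernel call, which suggests the kernel expects a prefix-sum over per-request KV lengths rather than batch start offsets; the second doubles the flashinfer workspace buffer from 256 MiB to 512 MiB. Below is a minimal sketch of how a kv_starts tensor of that shape could be built from b_seq_len. The helper name is hypothetical and the code is an illustration under that assumption, not lightllm's actual implementation.

    import torch

    def build_kv_starts(b_seq_len: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: exclusive prefix sum over per-request KV lengths,
        # so request i owns kv_indices[kv_starts[i]:kv_starts[i + 1]].
        kv_starts = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32, device=b_seq_len.device)
        kv_starts[1:] = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32)
        return kv_starts

    # Example: three requests holding 3, 5, and 2 KV-cache tokens.
    b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)
    print(build_kv_starts(b_seq_len))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

Under this reading, b_start_loc and kv_starts coincide only when every request's batch slice is exactly its KV length (e.g. pure prefill), which would explain why decode paths need the dedicated tensor.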