Skip to content

Commit 359063f

Browse files
reduce the memory of flashinfer (#850)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 6f2e76f commit 359063f

File tree

3 files changed: +17 −4 lines changed

lightllm/models/deepseek2/flashinfer_struct.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -23,9 +23,14 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
-                self.kv_indices = torch.empty(
-                    self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
-                ).to(input_ids.device)
+                if self.batch_size <= model.graph_max_batch_size:
+                    self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
+                        : self.batch_size * self.flashinfer_extra_state.max_seq_length
+                    ]
+                else:
+                    self.kv_indices = torch.empty(
+                        self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
+                    ).to(input_ids.device)
             repack_kv_index(
                 self.req_manager.req_to_token_indexs,
                 self.b_req_idx,
```

lightllm/models/deepseek2/model.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -31,6 +31,14 @@ def __init__(self, model):
         self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id())
         self.max_seq_length = model.max_seq_length
         self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+        self.kv_indices_buffer = [
+            torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
+                get_current_device_id()
+            ),
+            torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
+                get_current_device_id()
+            ),
+        ]
         if model.config["rope_scaling"] is not None:
             rope_scaling = model.config["rope_scaling"]
             mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
```

lightllm/server/api_start.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -142,7 +142,7 @@ def normal_or_p_d_start(args):
     else:
         # chunked 模式下
         if args.batch_max_tokens is None:
-            args.batch_max_tokens = min(args.max_req_total_len, 2 * args.chunked_prefill_size)
+            args.batch_max_tokens = min(args.max_req_total_len, 2 * args.chunked_prefill_size + 256)

     assert (
         args.batch_max_tokens >= args.chunked_prefill_size
```

0 commit comments

Comments (0)