
Commit 86d262a

Fix dynamic_prompt_cache for chunked prefill (#1010)
1 parent fe76c17 commit 86d262a

6 files changed: +2 additions, -12 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 2 deletions

@@ -62,7 +62,6 @@ def __init__(self, kvargs):
         self.is_token_healing = kvargs.get("is_token_healing", False)
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
-        self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_batch_size = (

@@ -251,7 +250,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.is_prefill = model_input.is_prefill
         infer_state.is_token_healing = self.is_token_healing
         infer_state.return_all_prompt_logics = self.return_all_prompt_logics
-        infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
         infer_state.batch_size = model_input.batch_size
         infer_state.total_token_num = model_input.total_token_num
         infer_state.max_len_in_batch = model_input.max_len_in_batch

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 1 deletion

@@ -35,7 +35,6 @@ def __init__(self):
 
         self.is_token_healing: bool = False
         self.return_all_prompt_logics: bool = False
-        self.use_dynamic_prompt_cache: bool = False
         self.multimodal_params: dict = None
         self.is_cuda_graph: bool = False  # marks whether this is a CUDA graph capture inference
         self.dist_group: CustomProcessGroup = None

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion

@@ -253,7 +253,7 @@ def _decompress_kv(
         b_kv_start_loc,
         skip_sample=False,
     ):
-        if infer_state.use_dynamic_prompt_cache and not skip_sample:
+        if not skip_sample:
             if is_fp8:
                 kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
                 kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)

lightllm/server/api_start.py

Lines changed: 1 addition & 4 deletions

@@ -94,11 +94,8 @@ def normal_or_p_d_start(args):
 
     if args.graph_max_len_in_batch == 0:
         args.graph_max_len_in_batch = args.max_req_total_len
-
+
     # mode setting check.
-    if not args.disable_chunked_prefill:
-        assert args.disable_dynamic_prompt_cache is False
-        assert args.disable_chunked_prefill is False
     if args.output_constraint_mode != "none":
         assert args.disable_dynamic_prompt_cache is False
         assert args.disable_chunked_prefill is False

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 0 additions & 2 deletions

@@ -124,7 +124,6 @@ def init_model(self, kvargs):
             "max_seq_length": kvargs.get("max_seq_length", 1024 * 5),
             "is_token_healing": kvargs.get("is_token_healing", False),
             "return_all_prompt_logics": self.return_all_prompt_logprobs,
-            "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache,
             "disable_chunked_prefill": self.disable_chunked_prefill,
             "data_type": kvargs.get("data_type", "float16"),
             "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16),

@@ -231,7 +230,6 @@ def init_mtp_draft_model(self, main_kvargs: dict):
             "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5),
             "is_token_healing": False,
             "return_all_prompt_logics": False,
-            "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache,
             "disable_chunked_prefill": self.disable_chunked_prefill,
             "data_type": main_kvargs.get("data_type", "float16"),
             "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16),

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 0 additions & 2 deletions

@@ -27,7 +27,6 @@ def init_mtp_model(args: StartArgs, kvargs, main_model):
         {
             "weight_dir": args.mtp_draft_model_dir,
             "max_total_token_num": main_model.mem_manager.size,
-            "use_dynamic_prompt_cache": False,
             "disable_chunked_prefill": True,
             "mtp_mode": args.mtp_mode,
             "main_model": main_model,

@@ -39,7 +38,6 @@ def init_mtp_model(args: StartArgs, kvargs, main_model):
         {
             "weight_dir": args.spec_model_dir,
             "max_total_token_num": main_model.mem_manager.size,
-            "use_dynamic_prompt_cache": False,
             "disable_chunked_prefill": True,
             "mtp_mode": args.mtp_mode,
             "main_model": main_model,
