Commit 071161a

bugfix: qwen3 fa3 inferstruct init, add b_prefill_start_loc for init_req_to_token_indexes (#1081)
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent db1b64c commit 071161a

File tree

4 files changed, +22 -6 lines changed
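
Summary of the change: the prefill input-preparation helpers now compute a per-request b_prefill_start_loc tensor (the exclusive prefix sum of the per-request query lengths b_q_seq_len), carry it on ModelInput, and pass it to init_req_to_token_indexes in place of infer_state.b_start_loc; init_some_extra_state is also moved to run after init_req_to_token_indexes. A minimal illustration of the start-loc computation used throughout the hunks below, with made-up query lengths:

import torch

# Hypothetical per-request query lengths for a 3-request prefill batch
# (new tokens only, i.e. b_seq_len - b_ready_cache_len).
b_q_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)

# Exclusive prefix sum, exactly as computed in the diff: request i's new
# tokens start after all earlier requests' new tokens.
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
print(b_prefill_start_loc)  # tensor([0, 5, 8], dtype=torch.int32)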

lightllm/common/basemodel/basemodel.py

Lines changed: 13 additions & 6 deletions
@@ -343,16 +343,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
@@ -482,28 +482,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -704,6 +704,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -721,6 +722,7 @@ def _check_max_len_infer(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
         )
         model_output = self.forward(
             model_input,
@@ -778,6 +780,7 @@ def _autotune_warmup(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = input_len
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -795,6 +798,7 @@ def _autotune_warmup(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
         )
@@ -838,6 +842,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -854,6 +860,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
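
The hunks above only show that init_some_extra_state now runs after init_req_to_token_indexes and that b_start_loc comes from model_input rather than infer_state; presumably the qwen3 FA3 extra state built there depends on the token-index table already being filled, which is what the commit title points at. Purely as an illustration of what the start-loc vector is used for, a simplified stand-in for init_req_to_token_indexes (an assumption-laden sketch, not the real lightllm implementation) might look like this:

import torch

def init_req_to_token_indexes_sketch(
    req_to_token_indexs: torch.Tensor,  # [max_requests, max_seq_len] slot table
    b_req_idx: torch.Tensor,            # row in the table for each batch element
    b_seq_len: torch.Tensor,            # total length, including the cached prefix
    b_ready_cache_len: torch.Tensor,    # prefix length already in the KV cache
    b_start_loc: torch.Tensor,          # offset of each request's new tokens in alloc_mem_index
    alloc_mem_index: torch.Tensor,      # flat tensor of newly allocated KV-cache slots
):
    # Hypothetical simplification: copy each request's freshly allocated slots
    # into its row of the req->token table, right after the cached prefix.
    for i in range(b_req_idx.numel()):
        start = int(b_start_loc[i])
        cached = int(b_ready_cache_len[i])
        total = int(b_seq_len[i])
        row = int(b_req_idx[i])
        req_to_token_indexs[row, cached:total] = alloc_mem_index[start : start + (total - cached)]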

lightllm/common/basemodel/batch_objs.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@ class ModelInput:
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
+    b_prefill_start_loc: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)
 
     # cpu variables
@@ -49,6 +50,8 @@ def to_cuda(self):
         self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
         if self.b_ready_cache_len is not None:
             self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+        if self.b_prefill_start_loc is not None:
+            self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)
 
 
 @dataclass
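
Because the new field defaults to None and to_cuda only copies it when present, existing non-prefill call sites are unaffected; the prefill paths below are expected to populate it. A hypothetical guard before a prefill forward (assuming model_input was built by one of the helpers in the next two files):

# Hypothetical sanity check; `model_input` is assumed to come from
# prepare_prefill_inputs / padded_prepare_prefill_inputs below.
if model_input.is_prefill:
    assert model_input.b_prefill_start_loc is not None, (
        "prefill batches must carry b_prefill_start_loc for init_req_to_token_indexes"
    )
model_input.to_cuda()  # moves b_prefill_start_loc to the GPU with the other batch tensors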

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,8 @@ def padded_prepare_prefill_inputs(
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
+    b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
 
     # dynamic prompt cache: prepare tokens
     g_infer_state_lock.acquire()
@@ -110,6 +112,7 @@ def padded_prepare_prefill_inputs(
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        b_prefill_start_loc=b_prefill_start_loc,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
     )

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,8 @@ def prepare_prefill_inputs(
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
+    b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
 
     # dynamic prompt cache: prepare tokens
     g_infer_state_lock.acquire()
@@ -78,6 +80,7 @@ def prepare_prefill_inputs(
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        b_prefill_start_loc=b_prefill_start_loc,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
         prefix_total_token_num=prefix_total_token_num,
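
Taken together: both prepare_prefill_inputs variants build b_prefill_start_loc on the CPU from b_q_seq_len, ModelInput.to_cuda moves it to the GPU, and _prefill / microbatch_overlap_prefill hand it to init_req_to_token_indexes before init_some_extra_state runs, which appears to be the ordering fix the commit title describes for the qwen3 FA3 infer-state initialization.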
