
Commit 6d23241

shihaobai authored and hiworldwzj committed
bugfix:qwen3 fa3 inferstruct init, add b_prefill_start_loc for init_req_to_token_indexes (#1081)
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 7e737fa commit 6d23241
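The fix: prefill input preparation now computes a per-request start offset tensor, b_prefill_start_loc, as an exclusive cumulative sum of the per-request query lengths (b_q_seq_len = b_seq_len - b_ready_cache_len), carries it through ModelInput, and hands it to init_req_to_token_indexes instead of infer_state.b_start_loc; init_some_extra_state is also moved so that it runs after init_req_to_token_indexes. A minimal sketch of the offset computation, using made-up example lengths (only the cumsum formula is taken from the diff):

import torch

# Made-up example batch: total sequence lengths and already-cached prefix lengths.
b_seq_len = torch.tensor([5, 3, 4], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 1], dtype=torch.int32)

# Number of new tokens each request contributes to this prefill step.
b_q_seq_len = b_seq_len - b_ready_cache_len              # tensor([3, 3, 3])

# Exclusive cumsum: where each request's new tokens start in the packed buffer.
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
print(b_prefill_start_loc)                               # tensor([0, 3, 6], dtype=torch.int32)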

File tree: 4 files changed, +22 -6 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 13 additions & 6 deletions
@@ -352,16 +352,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
@@ -491,28 +491,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -713,6 +713,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -730,6 +731,7 @@ def _check_max_len_infer(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
         )
         model_output = self.forward(
             model_input,
@@ -787,6 +789,7 @@ def _autotune_warmup(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = input_len
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -804,6 +807,7 @@ def _autotune_warmup(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
         )
@@ -847,6 +851,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -863,6 +869,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
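With the hunks above, init_some_extra_state now runs only after the req-to-token table has been filled, and the start offsets come from model_input rather than from infer_state. For context, here is a simplified, illustrative stand-in for init_req_to_token_indexes showing how the b_start_loc offsets are consumed. This is not the LightLLM implementation (the real helper also takes max_q_seq_len and is not a Python loop); the loop-based body and tensor shapes below are assumptions for illustration only.

import torch

def ref_init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len,
                                  b_ready_cache_len, b_start_loc, alloc_mem_index):
    # Illustrative reference: copy each request's newly allocated KV-cache token
    # indexes into its row of the req-to-token table, skipping cached prefix slots.
    for i in range(b_req_idx.shape[0]):
        req = int(b_req_idx[i])
        start = int(b_ready_cache_len[i])      # prefix tokens already cached
        end = int(b_seq_len[i])                # total tokens after this prefill
        off = int(b_start_loc[i])              # offset into the packed allocation
        req_to_token_indexs[req, start:end] = alloc_mem_index[off:off + (end - start)]

# Tiny usage example with made-up sizes (runs on CPU).
table = torch.zeros((4, 8), dtype=torch.int64)
b_req_idx = torch.tensor([0, 2])
b_seq_len = torch.tensor([5, 4])
b_ready_cache_len = torch.tensor([2, 1])
b_q_seq_len = b_seq_len - b_ready_cache_len
b_start_loc = b_q_seq_len.cumsum(dim=0) - b_q_seq_len
alloc_mem_index = torch.arange(100, 100 + int(b_q_seq_len.sum()))
ref_init_req_to_token_indexes(table, b_req_idx, b_seq_len, b_ready_cache_len,
                              b_start_loc, alloc_mem_index)
print(table[0])  # slots 2..4 hold indexes 100..102
print(table[2])  # slots 1..3 hold indexes 103..105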

lightllm/common/basemodel/batch_objs.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@ class ModelInput:
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
+    b_prefill_start_loc: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)
 
     # cpu variables
@@ -49,6 +50,8 @@ def to_cuda(self):
         self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
         if self.b_ready_cache_len is not None:
             self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+        if self.b_prefill_start_loc is not None:
+            self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)
 
 
 @dataclass
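The new field follows the existing ModelInput pattern: an optional CPU tensor that to_cuda() moves to the GPU only when it is present. A self-contained sketch of that pattern (a stand-in dataclass with only a few of ModelInput's fields, for illustration; calling to_cuda() requires a CUDA device):

import torch
from dataclasses import dataclass, field

@dataclass
class PrefillInputsSketch:
    # Optional per-request tensors, built on the CPU by the prepare_* helpers.
    b_ready_cache_len: torch.Tensor = None
    b_prefill_start_loc: torch.Tensor = None
    multimodal_params: list = field(default_factory=list)

    def to_cuda(self):
        # Guarded, non-blocking host-to-device copies, mirroring ModelInput.to_cuda.
        if self.b_ready_cache_len is not None:
            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
        if self.b_prefill_start_loc is not None:
            self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)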

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,8 @@ def padded_prepare_prefill_inputs(
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
+    b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
 
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
@@ -110,6 +112,7 @@ def padded_prepare_prefill_inputs(
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        b_prefill_start_loc=b_prefill_start_loc,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
     )

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,8 @@ def prepare_prefill_inputs(
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
     b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
+    b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
 
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
@@ -78,6 +80,7 @@ def prepare_prefill_inputs(
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
+        b_prefill_start_loc=b_prefill_start_loc,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
         prefix_total_token_num=prefix_total_token_num,
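Both prepare paths build the offsets on the CPU with the same inclusive-cumsum-minus-self expression before ModelInput is constructed. That expression is simply an exclusive prefix sum; a quick equivalence check with made-up query lengths (illustration only, not code from the repository):

import torch

b_q_seq_len = torch.tensor([3, 1, 4], dtype=torch.int32)

# Form used in this commit: inclusive cumsum minus the element itself.
start_a = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len

# Equivalent form: prepend a zero and drop the last inclusive-cumsum entry.
start_b = torch.cat(
    [torch.zeros(1, dtype=torch.int32), b_q_seq_len.cumsum(dim=0, dtype=torch.int32)[:-1]]
)

assert torch.equal(start_a, start_b)  # both are tensor([0, 3, 4], dtype=torch.int32)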
