Commit 8cc2325

back the infer_struct
1 parent 6cd8c56 commit 8cc2325

4 files changed, +17 -15 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 3 deletions
@@ -341,7 +341,7 @@ def _prefill(
             infer_state.mem_index,
         )
 
-        infer_state.init_some_extra_state(self, model_input)
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
@@ -365,7 +365,7 @@ def _decode(
             infer_state.b_seq_len,
             infer_state.mem_index,
         )
-        infer_state.init_some_extra_state(self, padded_model_input)
+        infer_state.init_some_extra_state(self, padded_model_input.input_ids)
 
         if self.graph.need_capture(find_graph_batch_size):
             infer_state.is_cuda_graph = True
@@ -386,7 +386,7 @@ def _decode(
             infer_state.b_seq_len,
             infer_state.mem_index,
         )
-        infer_state.init_some_extra_state(self, model_input)
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         model_output = self._token_forward(model_input.input_ids, infer_state)
 
         return model_output
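Taken together, the three hunks above restore the older call-site convention: _prefill and _decode pass the bare input_ids tensor (model_input.input_ids or padded_model_input.input_ids) to init_some_extra_state instead of the whole ModelInput. A minimal, self-contained sketch of that convention, where ToyInferState is a made-up stand-in for InferStateInfo rather than the repository class:

import torch

class ToyInferState:
    # Made-up stand-in for InferStateInfo, only to illustrate the restored signature.
    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        # The method only needs the tensor itself: its length and its device.
        self.input_token_num = input_ids.shape[0]
        self.device = input_ids.device

state = ToyInferState()
state.init_some_extra_state(model=None, input_ids=torch.arange(6))
print(state.input_token_num, state.device)  # 6 cpu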

lightllm/common/basemodel/infer_struct.py

Lines changed: 4 additions & 4 deletions
@@ -65,7 +65,7 @@ def __init__(self):
         # ... input uses this; it is not used by any other model or scenario
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
-    def init_some_extra_state(self, model, model_input: ModelInput):
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (
                 self.b_q_seq_len,
@@ -76,7 +76,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
                 self.max_q_seq_len,
                 self.max_kv_seq_len,
             ) = gen_prefill_params(
-                input_token_num=model_input.input_ids.shape[0],
+                input_token_num=input_ids.shape[0],
                 b_ready_cache_len=self.b_ready_cache_len,
                 b_seq_len=self.b_seq_len,
             )
@@ -88,10 +88,10 @@ def init_some_extra_state(self, model, model_input: ModelInput):
                 self.b_kv_seq_len,
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
+                self.max_q_seq_len,
+                self.max_kv_seq_len,
             ) = gen_decode_params(self.b_seq_len)
             self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]
-            self.max_q_seq_len = 1
-            self.max_kv_seq_len = model_input.max_len_in_batch
 
     def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
         for attr_name, attr_value in vars(new_infer_state).items():

lightllm/common/basemodel/triton_kernel/gen_decode_params.py

Lines changed: 3 additions & 1 deletion
@@ -10,5 +10,7 @@ def gen_decode_params(b_seq_len: torch.Tensor):
     position_ids = b_seq_len - 1
     b_q_seq_len = torch.ones_like(b_seq_len)
     b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
+    max_q_seq_len = b_q_seq_len.max().item()
+    max_kv_seq_len = b_kv_seq_len.max().item()
 
-    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
+    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
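A runnable toy version of the updated helper, for reference. gen_cumsum_pad0_tensor is not shown in this hunk, so it is approximated below as a cumulative sum with a leading zero, and b_kv_seq_len is assumed to equal b_seq_len in the decode case; both points are assumptions, not the repository implementation:

import torch
import torch.nn.functional as F

def toy_gen_decode_params(b_seq_len: torch.Tensor):
    # Toy stand-in for gen_decode_params after this commit (assumptions noted above).
    b_kv_seq_len = b_seq_len                     # assumed: one token is decoded per request
    position_ids = b_seq_len - 1
    b_q_seq_len = torch.ones_like(b_seq_len)
    pad0_cumsum = lambda x: F.pad(torch.cumsum(x, dim=0), (1, 0))  # approximation of gen_cumsum_pad0_tensor
    b1_cu_q_seq_len, b1_cu_kv_seq_len = pad0_cumsum(b_q_seq_len), pad0_cumsum(b_kv_seq_len)
    max_q_seq_len = b_q_seq_len.max().item()
    max_kv_seq_len = b_kv_seq_len.max().item()
    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len

b_seq_len = torch.tensor([5, 9, 3])
*_, max_q_seq_len, max_kv_seq_len = toy_gen_decode_params(b_seq_len)
print(max_q_seq_len, max_kv_seq_len)  # 1 9

Because b_q_seq_len is all ones at decode time, max_q_seq_len is always 1, matching the value infer_struct.py previously hard-coded; max_kv_seq_len now plays the role of the old model_input.max_len_in_batch.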

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 7 additions & 7 deletions
@@ -24,13 +24,13 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
             ]
         return cls._shared_page_table_buffer
 
-    def init_some_extra_state(self, model, model_input: ModelInput):
-        super().init_some_extra_state(model, model_input)
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
             self.page_table = torch.empty(
-                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=model_input.input_ids.device
+                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
             )
             self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
         else:
@@ -47,7 +47,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
                 ].reshape(self.batch_size, model.graph_max_len_in_batch)
             else:
                 self.page_table = torch.empty(
-                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=model_input.input_ids.device
+                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
                 )
 
             self.page_table[:, :max_seq_len_k].copy_(
@@ -58,7 +58,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
 
         if "offline_calibration_fp8kv" in model.mode:
             if self.is_prefill:
-                device = model_input.input_ids.device
+                device = input_ids.device
                 # q_scale and token_batch_ids are used for per-head quantization of q; to save resources they are initialized outside of inference
                 self.q_scale = torch.empty(
                     (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
@@ -78,7 +78,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
                 else torch.ones(
                     (self.mem_manager.layer_num, self.batch_size, head_num),
                     dtype=torch.float32,
-                    device=model_input.input_ids.device,
+                    device=input_ids.device,
                 )
             )
             self.v_descale = (
@@ -89,7 +89,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
                 else torch.ones(
                     (self.mem_manager.layer_num, self.batch_size, head_num),
                     dtype=torch.float32,
-                    device=model_input.input_ids.device,
+                    device=input_ids.device,
                 )
             )
             return
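Every touched line in this file used model_input.input_ids only for its .device, so the change is mechanical: the device is now read from the bare input_ids tensor. A trivial sketch of that pattern, with made-up sizes for illustration:

import torch

input_ids = torch.tensor([3, 7, 11])        # whatever tensor _prefill/_decode passes in
batch_size, max_len_in_batch = 2, 8         # made-up sizes, just for illustration
page_table = torch.empty(
    (batch_size, max_len_in_batch), dtype=torch.int32, device=input_ids.device
)
print(page_table.shape, page_table.device)  # torch.Size([2, 8]) cpu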
