Commit 2a312b1

remove prefill cpu sync
1 parent 16c8c79 commit 2a312b1

File tree

15 files changed: +189, -119 lines changed

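The hunks below all make the same move: prefill batch metadata that used to be allocated directly on the GPU is now built on the CPU, carried in new `*_cpu` fields of `ModelInput`, and copied to the device asynchronously, so the prefill path no longer blocks on a host-device synchronization just to read scalar values. A minimal standalone sketch of that pattern (illustrative only, not code from this repository; the explicit `pin_memory()` call is an assumption the commit itself does not make):

    import torch

    # Batch metadata is built on the host without touching the GPU.
    b_seq_len_cpu = torch.tensor([17, 33, 9], dtype=torch.int32)

    # Reductions on the CPU copy need no device work, hence no stream sync.
    max_kv_seq_len = int(b_seq_len_cpu.max())

    if torch.cuda.is_available():
        # Pinning lets the asynchronous copy overlap with GPU compute.
        b_seq_len = b_seq_len_cpu.pin_memory().cuda(non_blocking=True)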

lightllm/common/basemodel/basemodel.py

Lines changed: 48 additions & 50 deletions
@@ -81,7 +81,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

-        self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
+        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]

         self._init_datatype()
         self._init_config()

@@ -262,10 +262,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.b_req_idx = model_input.b_req_idx
         infer_state.b_seq_len = model_input.b_seq_len
         if model_input.is_prefill:
-            if model_input.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = model_input.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
+            assert model_input.b_ready_cache_len is not None
+            infer_state.b_ready_cache_len = model_input.b_ready_cache_len

         infer_state.multimodal_params = model_input.multimodal_params

@@ -337,14 +335,14 @@ def _prefill(
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
+            model_input.b_req_idx_cpu,
+            model_input.b_seq_len_cpu,
+            model_input.b_ready_cache_len_cpu,
             model_input.max_len_in_batch,
             infer_state.mem_index,
         )

-        infer_state.init_some_extra_state(self, model_input.input_ids)
+        infer_state.init_some_extra_state(self, model_input)
         return self._context_forward(model_input.input_ids, infer_state)

     def _decode(

@@ -369,7 +367,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, padded_model_input.input_ids)
+            infer_state.init_some_extra_state(self, padded_model_input)

             if self.graph.need_capture(find_graph_batch_size):
                 infer_state.is_cuda_graph = True

@@ -390,7 +388,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, model_input.input_ids)
+            infer_state.init_some_extra_state(self, model_input)
             model_output = self._token_forward(model_input.input_ids, infer_state)

         return model_output

@@ -540,15 +538,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state0.b_seq_len,
                 infer_state0.mem_index,
             )
-            infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
+            infer_state0.init_some_extra_state(self, padded_model_input0)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
             copy_kv_index_to_req(
                 self.req_manager.req_to_token_indexs,
                 infer_state1.b_req_idx,
                 infer_state1.b_seq_len,
                 infer_state1.mem_index,
             )
-            infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
+            infer_state1.init_some_extra_state(self, padded_model_input1)

             if self.graph.need_capture(find_graph_batch_size):
                 infer_state0.is_cuda_graph = True

@@ -684,25 +682,25 @@ def _check_max_len_infer(self):
         # Simulate a prefill at the maximum length and watch for OOM.
         try:
             logger.info("begin check max_len infer")
-            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
-            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cpu")
+            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+            b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
             b_seq_len[:] = self.batch_max_tokens
-            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
             total_token_num = self.batch_max_tokens
-            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
             model_input = ModelInput(
                 batch_size=1,
                 total_token_num=total_token_num,
                 max_len_in_batch=self.batch_max_tokens,
-                input_ids=dummy_input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
+                input_ids_cpu=dummy_input_ids,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
+                b_mtp_index_cpu=b_mtp_index,
                 is_prefill=True,
-                b_ready_cache_len=b_ready_cache_len,
+                b_ready_cache_len_cpu=b_ready_cache_len,
             )
             model_output = self.forward(
                 model_input,

@@ -750,29 +748,29 @@ def _autotune_warmup(self):
         self.layers_num = self.autotune_layers()
         for input_len in tqdm(warmup_lengths, desc="warming up"):
             try:
-                rand_gen = torch.Generator(device="cuda")
+                rand_gen = torch.Generator(device="cpu")
                 rand_gen.manual_seed(input_len)
                 dummy_input_ids = torch.randint(
-                    0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
+                    0, 10000, (input_len,), dtype=torch.int32, device="cpu", generator=rand_gen
                 )
-                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+                b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
                 b_seq_len[:] = input_len
-                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
                 total_token_num = input_len
-                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
                 model_input = ModelInput(
                     batch_size=1,
                     total_token_num=total_token_num,
                     max_len_in_batch=input_len,
-                    input_ids=dummy_input_ids,
-                    mem_indexes=mem_indexes,
-                    b_req_idx=b_req_idx,
-                    b_seq_len=b_seq_len,
-                    b_mtp_index=b_mtp_index,
+                    input_ids_cpu=dummy_input_ids,
+                    mem_indexes_cpu=mem_indexes,
+                    b_req_idx_cpu=b_req_idx,
+                    b_seq_len_cpu=b_seq_len,
+                    b_mtp_index_cpu=b_mtp_index,
                     is_prefill=True,
-                    b_ready_cache_len=b_ready_cache_len,
+                    b_ready_cache_len_cpu=b_ready_cache_len,
                     multimodal_params=[],
                     **self._gen_special_model_input(total_token_num),
                 )

@@ -807,27 +805,27 @@ def _init_padded_req(self):
         # prefill init padding req.
         prefill_input_len = 1
         batch_size = 1
-        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
         b_req_idx = torch.tensor(
-            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
        )
         mem_indexes = torch.tensor(
-            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cpu"
         )
-        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cpu")
+        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
         total_token_num = prefill_input_len * batch_size
-        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
         model_input = ModelInput(
             batch_size=batch_size,
             total_token_num=total_token_num,
             max_len_in_batch=prefill_input_len,
-            input_ids=dummy_input_ids,
-            mem_indexes=mem_indexes,
-            b_req_idx=b_req_idx,
-            b_mtp_index=b_mtp_index,
-            b_seq_len=b_seq_len,
-            b_ready_cache_len=b_ready_cache_len,
+            input_ids_cpu=dummy_input_ids,
+            mem_indexes_cpu=mem_indexes,
+            b_req_idx_cpu=b_req_idx,
+            b_mtp_index_cpu=b_mtp_index,
+            b_seq_len_cpu=b_seq_len,
+            b_ready_cache_len_cpu=b_ready_cache_len,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
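For orientation, a hedged sketch of the construction pattern the hunks above converge on (sizes are made up and the manager objects are omitted): every tensor feeding `ModelInput` starts out on the CPU and is passed through the new `*_cpu` keyword arguments, leaving the single `to_cuda()` call (see `batch_objs.py` below) to perform the device transfer.

    import torch

    batch_max_tokens = 16  # hypothetical value, for illustration only

    dummy_input_ids = torch.ones(batch_max_tokens, dtype=torch.int32, device="cpu")
    b_seq_len = torch.full((1,), batch_max_tokens, dtype=torch.int32, device="cpu")
    b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
    b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")

    # In the real code these feed ModelInput(input_ids_cpu=..., b_seq_len_cpu=...,
    # b_ready_cache_len_cpu=..., b_mtp_index_cpu=..., is_prefill=True, ...);
    # no CUDA tensor exists until ModelInput.to_cuda() runs.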

lightllm/common/basemodel/batch_objs.py

Lines changed: 21 additions & 11 deletions
@@ -10,17 +10,22 @@ class ModelInput:
     batch_size: int
     total_token_num: int
     max_len_in_batch: int
-    input_ids: torch.Tensor
-    b_req_idx: torch.Tensor
-    b_mtp_index: torch.Tensor
-    b_seq_len: torch.Tensor
+    input_ids: torch.Tensor = None
+    b_req_idx: torch.Tensor = None
+    b_mtp_index: torch.Tensor = None
+    b_seq_len: torch.Tensor = None
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)

     # CPU-side variables
+    input_ids_cpu: torch.Tensor = None
+    b_req_idx_cpu: torch.Tensor = None
+    b_mtp_index_cpu: torch.Tensor = None
     mem_indexes_cpu: torch.Tensor = None
+    b_seq_len_cpu: torch.Tensor = None
+    b_ready_cache_len_cpu: torch.Tensor = None
     # Parameters used during the prefill stage; they are not consumed by the forward pass itself,
     # but by the resource management that happens outside of inference.
     b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces output

@@ -33,15 +38,20 @@ class ModelInput:
     deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

     def to_cuda(self):
-        if self.input_ids is not None:
-            self.input_ids = self.input_ids.cuda(non_blocking=True)
+        # input_ids may be absent; in that case it is recovered through req_to_token_indexs.
+        if self.input_ids is None and self.input_ids_cpu is not None:
+            self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
         if self.mem_indexes is None:
             self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
-        self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
-        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
-        self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
-        if self.b_ready_cache_len is not None:
-            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+        if self.b_req_idx is None:
+            self.b_req_idx = self.b_req_idx_cpu.cuda(non_blocking=True)
+        if self.b_seq_len is None:
+            self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)
+        # b_ready_cache_len only takes effect during prefill.
+        if self.b_ready_cache_len_cpu is not None:
+            self.b_ready_cache_len = self.b_ready_cache_len_cpu.cuda(non_blocking=True)
+        if self.b_mtp_index is None:
+            self.b_mtp_index = self.b_mtp_index_cpu.cuda(non_blocking=True)


 @dataclass
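Condensed to a single field, the lazy-transfer pattern `ModelInput.to_cuda()` now follows looks roughly like the sketch below (the class and field names are illustrative; the real dataclass carries the full set of `*_cpu` twins shown above). Because each GPU field stays optional, a caller that already holds a device tensor can set it directly and the copy is skipped.

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class LazyBatchSketch:
        b_seq_len: Optional[torch.Tensor] = None       # GPU tensor, filled lazily
        b_seq_len_cpu: Optional[torch.Tensor] = None   # host-side source of truth

        def to_cuda(self) -> None:
            # Issue the host-to-device copy only if the GPU tensor is still missing.
            if self.b_seq_len is None and self.b_seq_len_cpu is not None:
                self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)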

lightllm/common/basemodel/cuda_graph.py

Lines changed: 19 additions & 17 deletions
@@ -195,25 +195,27 @@ def warmup(self, model):
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
-            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
             mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
-                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
             )
-            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
             b_seq_len.fill_(seq_len)
-            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

             model_input = ModelInput(
                 batch_size=batch_size,
                 total_token_num=total_token_num,
                 max_len_in_batch=max_len_in_batch,
-                input_ids=input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
+                input_ids_cpu=input_ids,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
+                b_mtp_index_cpu=b_mtp_index,
                 is_prefill=False,
+                b_ready_cache_len_cpu=b_ready_cache_len,
                 **model._gen_special_model_input(batch_size),
             )
             model_output: ModelOutput = model.forward(model_input)

@@ -251,25 +253,25 @@ def warmup_overlap(self, model):
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
-            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
             mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
-                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
             )
-            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
             b_seq_len.fill_(seq_len)
-            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

             micro_batch = ModelInput(
                 is_prefill=False,
                 batch_size=batch_size,
                 total_token_num=total_token_num,
                 max_len_in_batch=max_len_in_batch,
-                input_ids=input_ids,
+                input_ids_cpu=input_ids,
                 b_mtp_index=b_mtp_index,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
                 **model._gen_special_model_input(batch_size),
             )
             decode_batches.append(micro_batch)

lightllm/common/basemodel/infer_struct.py

Lines changed: 2 additions & 4 deletions
@@ -64,7 +64,7 @@ def __init__(self):
         # used only as the input there; no other model or scenario uses it.
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model, model_input: ModelInput):
         if self.is_prefill:
             (
                 self.b_q_seq_len,

@@ -75,9 +75,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.max_q_seq_len,
                 self.max_kv_seq_len,
             ) = gen_prefill_params(
-                input_token_num=input_ids.shape[0],
-                b_ready_cache_len=self.b_ready_cache_len,
-                b_seq_len=self.b_seq_len,
+                model_input,
             )
             self.b_start_loc = self.b1_cu_q_seq_len[0:-1]
         else:

lightllm/common/basemodel/triton_kernel/gen_prefill_params.py

Lines changed: 12 additions & 3 deletions
@@ -3,6 +3,8 @@
 import triton
 import triton.language as tl

+from lightllm.common.basemodel.batch_objs import ModelInput
+

 @triton.jit
 def _gen_cumsum_pad0_kernel(

@@ -80,7 +82,14 @@ def _gen_prefill_position(


 @torch.no_grad()
-def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
+def gen_prefill_params(model_input: ModelInput):
+    # input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
+    input_token_num = model_input.input_ids.shape[0]
+    b_seq_len = model_input.b_seq_len
+    b_ready_cache_len = model_input.b_ready_cache_len
+    b_seq_len_cpu = model_input.b_seq_len_cpu
+    b_ready_cache_len_cpu = model_input.b_ready_cache_len_cpu
+
     batch_size = b_ready_cache_len.shape[0]
     position_ids = torch.empty((input_token_num,), dtype=torch.int32, device="cuda")
     assert b_ready_cache_len.shape[0] == b_seq_len.shape[0]

@@ -99,6 +108,6 @@ def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_
         num_stages=1,
     )
     b_kv_seq_len = b_seq_len
-    max_q_seq_len = b_q_seq_len.max().item()
-    max_kv_seq_len = b_kv_seq_len.max().item()
+    max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()
+    max_kv_seq_len = b_seq_len_cpu.max()
     return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
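This last hunk is where the prefill CPU sync actually disappears: `.max().item()` on a CUDA tensor forces the host to wait for the kernel that produced `b_q_seq_len`, while the same maxima taken from the CPU copies involve no device work at all. A small illustration with made-up values (note the results are now CPU tensors rather than Python ints):

    import torch

    # Hypothetical per-request lengths, host side only.
    b_seq_len_cpu = torch.tensor([128, 64], dtype=torch.int32)
    b_ready_cache_len_cpu = torch.tensor([100, 0], dtype=torch.int32)

    # The same quantities the new code computes, with no GPU synchronization.
    max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()  # tensor(64, dtype=torch.int32)
    max_kv_seq_len = b_seq_len_cpu.max()                           # tensor(128, dtype=torch.int32)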
