
Commit 6362c4a

merge main
2 parents 432c521 + 4ca8b78


58 files changed: +2511 -1047 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 387 additions & 258 deletions
Large diffs are not rendered by default.
lightllm/common/basemodel/batch_objs.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import torch
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelInput:
+    # common fields
+    batch_size: int
+    total_token_num: int
+    max_len_in_batch: int
+    input_ids: torch.Tensor
+    mem_indexes: torch.Tensor
+    b_req_idx: torch.Tensor
+    b_seq_len: torch.Tensor
+    is_prefill: bool = False
+    b_ready_cache_len: torch.Tensor = None
+    multimodal_params: list = field(default_factory=list)
+
+    # model-specific fields, used to pass extra inputs for special models
+    # or special modes; they are only used and take effect in those modes.
+
+    # deepseekv3_mtp_draft_input_hiddens is the input of the draft model
+    # when the deepseekv3 model runs in mtp mode.
+    deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+
+
+@dataclass
+class ModelOutput:
+    # common fields
+    logits: torch.Tensor
+
+    # model-specific fields, used to pass extra outputs for special models
+    # or special modes; they are only used and take effect in those modes.
+
+    # deepseekv3_mtp_main_output_hiddens holds the last-layer hidden state
+    # of the llm main model in mtp mode, which feeds the draft model's
+    # deepseekv3_mtp_draft_input_hiddens input.
+    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
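
For orientation only, not part of this commit: a minimal sketch of filling the new containers for a single dummy decode step, mirroring the warmup code in cuda_graph.py below. The import path is inferred from the import statement in cuda_graph.py, and the tensor values, dtypes, and the model object are placeholders.

import torch
from lightllm.common.basemodel.batch_objs import ModelInput

# Placeholder values; in real use mem_indexes and b_req_idx come from the
# memory and request managers, not from arange/zeros.
batch_size = 4
model_input = ModelInput(
    batch_size=batch_size,
    total_token_num=batch_size * 2,
    max_len_in_batch=8192,
    input_ids=torch.ones(batch_size, dtype=torch.int32, device="cuda"),
    mem_indexes=torch.arange(batch_size, device="cuda"),
    b_req_idx=torch.zeros(batch_size, dtype=torch.int32, device="cuda"),
    b_seq_len=torch.full((batch_size,), 2, dtype=torch.int32, device="cuda"),
    is_prefill=False,
)
# model_output = model.forward(model_input)   # returns a ModelOutput
# logits = model_output.logits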

lightllm/common/basemodel/cuda_graph.py

Lines changed: 133 additions & 97 deletions
@@ -1,10 +1,14 @@
 import os
 import torch
 import copy
+import bisect
+from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
-from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from .infer_struct import InferStateInfo
+

 logger = init_logger(__name__)

@@ -17,15 +21,48 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
-        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
+        self.args = get_env_start_args()
+        self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+
+        # gen cuda graph batch_sizes
+        # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
+        # and [graph_split_batch_size + graph_grow_step_size,
+        # graph_split_batch_size + 2 * graph_grow_step_size, ..., self.max_batch_size]
+        graph_split_batch_size = self.args.graph_split_batch_size
+        max_batch_size = self.max_batch_size
+        graph_grow_step_size = self.args.graph_grow_step_size
+
+        batch_sizes = [i for i in range(1, graph_split_batch_size + 1)]
+        for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
+            batch_sizes.append(_batch_size)
+
+        batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
+        batch_sizes.append(max_batch_size)
+        batch_sizes.sort()
+
+        self.cuda_graph_batch_sizes = batch_sizes
+        assert batch_sizes[-1] == self.max_batch_size
+        logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

     def need_capture(self, batch_size):
-        return batch_size not in self.graph
+        find_batch_size = self.find_closest_graph_batch_size(batch_size)
+        if find_batch_size is not None:
+            return find_batch_size not in self.graph
+        else:
+            assert False, "dead code"

-    def _capture_decode(self, decode_func, input_ids, infer_state):
+    def find_closest_graph_batch_size(self, batch_size):
+        index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
+        if index < len(self.cuda_graph_batch_sizes):
+            find_batch_size = self.cuda_graph_batch_sizes[index]
+            return find_batch_size
+        else:
+            return None
+
+    def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
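
Illustration only, not part of this commit: the same batch-size bucketing and bisect lookup as above, run standalone with made-up values for graph_split_batch_size, graph_grow_step_size, and max_batch_size (the real values come from get_env_start_args()).

import bisect

# Hypothetical example values.
graph_split_batch_size = 8
graph_grow_step_size = 4
max_batch_size = 23

batch_sizes = [i for i in range(1, graph_split_batch_size + 1)]  # dense: 1..8
for b in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
    batch_sizes.append(b)  # strided: 12, 16, 20
batch_sizes = sorted(set(e for e in batch_sizes if e < max_batch_size))
batch_sizes.append(max_batch_size)  # the max batch size is always captured

print(batch_sizes)  # [1, 2, 3, 4, 5, 6, 7, 8, 12, 16, 20, 23]

# A runtime batch size is mapped to the closest captured size >= it.
print(batch_sizes[bisect.bisect_left(batch_sizes, 10)])  # 12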
@@ -46,12 +83,19 @@ def _capture_decode(self, decode_func, input_ids, infer_state):

         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
-                predict_logics = decode_func(input_ids, infer_state)
-        self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
+                model_output = decode_func(input_ids, infer_state)
+        self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
         graph_obj.replay()
-        return predict_logics
+        return model_output

-    def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids1, infer_state1):
+    def _capture_decode_overlap(
+        self,
+        decode_func,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: torch.Tensor,
+        infer_state1: InferStateInfo,
+    ):
         dist_group: CustomProcessGroup = infer_state.dist_group
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
@@ -68,20 +112,27 @@ def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids
         with lightllm_capture_graph(dist_group1):
             with lightllm_capture_graph(dist_group):
                 with torch.cuda.graph(graph_obj, pool=self.mempool):
-                    predict_logics, predict_logics1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+                    model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
         self.graph[batch_size] = (
             graph_obj,
             input_ids,
             infer_state,
             input_ids1,
             infer_state1,
-            predict_logics,
-            predict_logics1,
+            model_output,
+            model_output1,
         )
         graph_obj.replay()
-        return predict_logics, predict_logics1
+        return model_output, model_output1

-    def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, infer_state1=None):
+    def capture_decode(
+        self,
+        decode_func,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: Optional[torch.Tensor] = None,
+        infer_state1: Optional[torch.Tensor] = None,
+    ):
         """
         Capture the cuda graph for the decoding stage.
         input_ids1 and infer_state1 is used for the overlap.
@@ -92,31 +143,37 @@ def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, i
             assert input_ids1 is None and infer_state1 is None
             return self._capture_decode(decode_func, input_ids, infer_state)

-    def _replay(self, input_ids, infer_state):
+    def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo):
         batch_size = input_ids.shape[0]
-        graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size]
+        graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size]
         graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_obj.replay()
-        return graph_predict_logics
+        return graph_output

-    def _replay_overlap(self, input_ids, infer_state, input_ids1, infer_state1):
+    def _replay_overlap(
+        self,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: torch.Tensor,
+        infer_state1: InferStateInfo,
+    ):
         batch_size = input_ids.shape[0]
         (
             graph_obj,
             graph_input_ids,
             graph_infer_state,
             graph_input_ids1,
             graph_infer_state1,
-            graph_predict_logics,
-            graph_predict_logics1,
+            graph_model_output,
+            graph_model_output1,
         ) = self.graph[batch_size]
         graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_input_ids1.copy_(input_ids1)
         graph_infer_state1.copy_for_cuda_graph(infer_state1)
         graph_obj.replay()
-        return graph_predict_logics, graph_predict_logics1
+        return graph_model_output, graph_model_output1

     def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
         if self.enable_decode_microbatch_overlap:
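
Background sketch, not lightllm code: the stock torch.cuda.CUDAGraph capture/replay pattern that _capture_decode and _replay build on. A captured graph replays fixed kernels on fixed buffers, so fresh inputs must be copied into the captured tensors before replay; that is what graph_input_ids.copy_(...) and copy_for_cuda_graph(...) do above.

import torch

# Requires a CUDA device.
static_x = torch.zeros(4, device="cuda")
pool = torch.cuda.graph_pool_handle()  # shared memory pool, like self.mempool

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, pool=pool):
    static_y = static_x * 2

# Replay with new data: copy into the captured input buffer first.
static_x.copy_(torch.arange(4, dtype=torch.float32, device="cuda"))
g.replay()
print(static_y)  # tensor([0., 2., 4., 6.], device='cuda:0')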
@@ -128,59 +185,50 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        for batch_size in range(self.max_batch_size, self.max_batch_size - 1, -1):
-            # dummy prefill
-            prefill_input_len = 1
-            dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        # for typing easy
+        from .basemodel import TpPartBaseModel
+
+        model: TpPartBaseModel = model
+
+        # decode cuda graph init
+        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+            seq_len = 2
+            total_token_num = batch_size * seq_len
+            max_len_in_batch = self.graph_max_len_in_batch
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
-                [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-            )
-            mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-            total_token_num = prefill_input_len * batch_size
-            logics = model.forward(
-                batch_size,
-                total_token_num,
-                prefill_input_len,
-                dummy_input_ids,
-                mem_indexes,
-                b_req_idx,
-                b_seq_len,
-                b_ready_cache_len=b_ready_cache_len,
-                is_prefill=True,
-                multimodal_params=[],
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
-            mem_indexes = None
-            prob_out = torch.softmax(logics, dim=-1)
-            logics = None
-            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-            prob_out = None
-            predict_ids = predict_ids.detach().cpu().numpy()
-            torch.cuda.empty_cache()
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len.fill_(seq_len)

-            # dummy decoding, capture the cudagraph
-            total_token_num += batch_size
-            b_seq_len += 1
-            mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
-            logics = model.forward(
-                batch_size,
-                total_token_num,
-                prefill_input_len + 1,
-                torch.from_numpy(predict_ids).cuda().reshape(-1),
-                mem_indexes,
-                b_req_idx,
-                b_seq_len,
+            model_input = ModelInput(
+                batch_size=batch_size,
+                total_token_num=total_token_num,
+                max_len_in_batch=max_len_in_batch,
+                input_ids=input_ids,
+                mem_indexes=mem_indexes,
+                b_req_idx=b_req_idx,
+                b_seq_len=b_seq_len,
                 is_prefill=False,
+                **model._gen_special_model_input(batch_size),
             )
-            mem_indexes = None
+            model_output: ModelOutput = model.forward(model_input)
+            del model_output
+            del input_ids
+            del mem_indexes
+            del b_req_idx
+            del b_seq_len
+
             model.mem_manager.free_all()
             model.req_manager.free_all()
             # release local tensors
             for var_name, var_value in list(locals().items()):
                 if isinstance(var_value, torch.Tensor):
                     del locals()[var_name]
            torch.cuda.empty_cache()
+
         logger.info(
             f"Capture cudagraph success, batch_size <={self.max_batch_size} "
             f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."
@@ -189,64 +237,52 @@ def warmup(self, model):
     @torch.no_grad()
     def warmup_overlap(self, model):
         logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.")
-        for batch_size in range(self.max_batch_size, 0, -1):
+        # for typing easy
+        from .basemodel import TpPartBaseModel
+
+        model: TpPartBaseModel = model
+
+        for batch_size in self.cuda_graph_batch_sizes[::-1]:
             decode_batches = []
             for micro_batch_index in [0, 1]:
-                # dummy prefill
-                prefill_input_len = 1
-                dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+                # dummy decoding, capture the cudagraph
+                seq_len = 2
+                total_token_num = batch_size * seq_len
+                max_len_in_batch = self.graph_max_len_in_batch
+                input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+                mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
                 b_req_idx = torch.tensor(
-                    [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-                )
-                mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-                b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-                total_token_num = prefill_input_len * batch_size
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len,
-                    dummy_input_ids,
-                    mem_indexes,
-                    b_req_idx,
-                    b_seq_len,
-                    b_ready_cache_len=b_ready_cache_len,
-                    is_prefill=True,
-                    multimodal_params=[],
+                    [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
                 )
-                mem_indexes = None
-                prob_out = torch.softmax(logics, dim=-1)
-                logics = None
-                predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-                prob_out = None
-                predict_ids = predict_ids.detach().cpu().numpy()
-                torch.cuda.empty_cache()
-
-                # dummy decoding, capture the cudagraph
-                total_token_num += batch_size
-                b_seq_len += 1
-                mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
+                b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+                b_seq_len.fill_(seq_len)

-                micro_batch = DecodeMicroBatch(
+                micro_batch = ModelInput(
+                    is_prefill=False,
                     batch_size=batch_size,
                     total_token_num=total_token_num,
-                    max_len_in_batch=prefill_input_len + 1,
-                    input_ids=torch.from_numpy(predict_ids).cuda().reshape(-1),
+                    max_len_in_batch=max_len_in_batch,
+                    input_ids=input_ids,
                     mem_indexes=mem_indexes,
                     b_req_idx=b_req_idx,
                     b_seq_len=b_seq_len,
+                    **model._gen_special_model_input(batch_size),
                 )
                 decode_batches.append(micro_batch)
+                del micro_batch

                 for var_name, var_value in list(locals().items()):
                     if isinstance(var_value, torch.Tensor):
                         del locals()[var_name]
                 torch.cuda.empty_cache()
+
             _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1])

             model.mem_manager.free_all()
             model.req_manager.free_all()

+            del decode_batches
+
             # release local tensors
             for var_name, var_value in list(locals().items()):
                 if isinstance(var_value, torch.Tensor):