 import os
 import torch
 import copy
+import bisect
+from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
-from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from .infer_struct import InferStateInfo
+

 logger = init_logger(__name__)

@@ -17,15 +21,48 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
-        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
+        self.args = get_env_start_args()
+        self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+
+        # generate the cuda graph batch sizes:
+        # graphs are captured for batch_size in [1, 2, 3, ..., graph_split_batch_size]
+        # and then in [graph_split_batch_size + graph_grow_step_size,
+        #  graph_split_batch_size + 2 * graph_grow_step_size, ..., self.max_batch_size]
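+        # e.g. with hypothetical settings graph_split_batch_size=4, graph_grow_step_size=4,
+        # max_batch_size=16, this yields [1, 2, 3, 4, 8, 12, 16]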
+        graph_split_batch_size = self.args.graph_split_batch_size
+        max_batch_size = self.max_batch_size
+        graph_grow_step_size = self.args.graph_grow_step_size
+
+        batch_sizes = [i for i in range(1, graph_split_batch_size + 1)]
+        for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
+            batch_sizes.append(_batch_size)
+
+        batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
+        batch_sizes.append(max_batch_size)
+        batch_sizes.sort()
+
+        self.cuda_graph_batch_sizes = batch_sizes
+        assert batch_sizes[-1] == self.max_batch_size
+        logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

     def need_capture(self, batch_size):
-        return batch_size not in self.graph
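+        # the request batch is mapped to the nearest captured batch size, so the graph dict is keyed by that size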
+        find_batch_size = self.find_closest_graph_batch_size(batch_size)
+        if find_batch_size is not None:
+            return find_batch_size not in self.graph
+        else:
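+            # unreachable after can_run() has passed: max_batch_size is always the last captured batch size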
+            assert False, "dead code"

-    def _capture_decode(self, decode_func, input_ids, infer_state):
+    def find_closest_graph_batch_size(self, batch_size):
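+        # bisect_left gives the index of the smallest captured batch size that is >= batch_size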
+        index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
+        if index < len(self.cuda_graph_batch_sizes):
+            find_batch_size = self.cuda_graph_batch_sizes[index]
+            return find_batch_size
+        else:
+            return None
+
+    def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
@@ -46,12 +83,19 @@ def _capture_decode(self, decode_func, input_ids, infer_state):

         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
-                predict_logics = decode_func(input_ids, infer_state)
-        self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
+                model_output = decode_func(input_ids, infer_state)
+        self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
         graph_obj.replay()
-        return predict_logics
+        return model_output

-    def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids1, infer_state1):
+    def _capture_decode_overlap(
+        self,
+        decode_func,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: torch.Tensor,
+        infer_state1: InferStateInfo,
+    ):
         dist_group: CustomProcessGroup = infer_state.dist_group
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
@@ -68,20 +112,27 @@ def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids
         with lightllm_capture_graph(dist_group1):
             with lightllm_capture_graph(dist_group):
                 with torch.cuda.graph(graph_obj, pool=self.mempool):
-                    predict_logics, predict_logics1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+                    model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
         self.graph[batch_size] = (
             graph_obj,
             input_ids,
             infer_state,
             input_ids1,
             infer_state1,
-            predict_logics,
-            predict_logics1,
+            model_output,
+            model_output1,
         )
         graph_obj.replay()
-        return predict_logics, predict_logics1
+        return model_output, model_output1

-    def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, infer_state1=None):
+    def capture_decode(
+        self,
+        decode_func,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: Optional[torch.Tensor] = None,
+        infer_state1: Optional[InferStateInfo] = None,
+    ):
         """
         Capture the cuda graph for the decoding stage.
         input_ids1 and infer_state1 are used for the overlap mode.
@@ -92,31 +143,37 @@ def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, i
             assert input_ids1 is None and infer_state1 is None
             return self._capture_decode(decode_func, input_ids, infer_state)

-    def _replay(self, input_ids, infer_state):
+    def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo):
         batch_size = input_ids.shape[0]
-        graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size]
+        graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size]
         graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_obj.replay()
-        return graph_predict_logics
+        return graph_output

-    def _replay_overlap(self, input_ids, infer_state, input_ids1, infer_state1):
+    def _replay_overlap(
+        self,
+        input_ids: torch.Tensor,
+        infer_state: InferStateInfo,
+        input_ids1: torch.Tensor,
+        infer_state1: InferStateInfo,
+    ):
         batch_size = input_ids.shape[0]
         (
             graph_obj,
             graph_input_ids,
             graph_infer_state,
             graph_input_ids1,
             graph_infer_state1,
-            graph_predict_logics,
-            graph_predict_logics1,
+            graph_model_output,
+            graph_model_output1,
         ) = self.graph[batch_size]
         graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_input_ids1.copy_(input_ids1)
         graph_infer_state1.copy_for_cuda_graph(infer_state1)
         graph_obj.replay()
-        return graph_predict_logics, graph_predict_logics1
+        return graph_model_output, graph_model_output1

     def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
         if self.enable_decode_microbatch_overlap:
@@ -128,59 +185,50 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        for batch_size in range(self.max_batch_size, self.max_batch_size - 1, -1):
-            # dummy prefill
-            prefill_input_len = 1
-            dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        # local import, only used for the type annotation below
+        from .basemodel import TpPartBaseModel
+
+        model: TpPartBaseModel = model
+
+        # decode cuda graph init
+        for batch_size in self.cuda_graph_batch_sizes[::-1]:
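+            # build a dummy decode batch: every request reports a length of 2 tokens
+            # and points at the placeholder HOLD_REQUEST_ID slot of the request manager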
+            seq_len = 2
+            total_token_num = batch_size * seq_len
+            max_len_in_batch = self.graph_max_len_in_batch
+            input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
-                [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-            )
-            mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-            b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-            total_token_num = prefill_input_len * batch_size
-            logics = model.forward(
-                batch_size,
-                total_token_num,
-                prefill_input_len,
-                dummy_input_ids,
-                mem_indexes,
-                b_req_idx,
-                b_seq_len,
-                b_ready_cache_len=b_ready_cache_len,
-                is_prefill=True,
-                multimodal_params=[],
+                [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
-            mem_indexes = None
-            prob_out = torch.softmax(logics, dim=-1)
-            logics = None
-            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-            prob_out = None
-            predict_ids = predict_ids.detach().cpu().numpy()
-            torch.cuda.empty_cache()
+            b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+            b_seq_len.fill_(seq_len)

-            # dummy decoding, capture the cudagraph
-            total_token_num += batch_size
-            b_seq_len += 1
-            mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
-            logics = model.forward(
-                batch_size,
-                total_token_num,
-                prefill_input_len + 1,
-                torch.from_numpy(predict_ids).cuda().reshape(-1),
-                mem_indexes,
-                b_req_idx,
-                b_seq_len,
+            model_input = ModelInput(
+                batch_size=batch_size,
+                total_token_num=total_token_num,
+                max_len_in_batch=max_len_in_batch,
+                input_ids=input_ids,
+                mem_indexes=mem_indexes,
+                b_req_idx=b_req_idx,
+                b_seq_len=b_seq_len,
                 is_prefill=False,
+                **model._gen_special_model_input(batch_size),
             )
-            mem_indexes = None
+            model_output: ModelOutput = model.forward(model_input)
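+            # drop the python references so the dummy tensors can actually be freed below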
+            del model_output
+            del input_ids
+            del mem_indexes
+            del b_req_idx
+            del b_seq_len
+
             model.mem_manager.free_all()
             model.req_manager.free_all()
             # release local tensors
             for var_name, var_value in list(locals().items()):
                 if isinstance(var_value, torch.Tensor):
                     del locals()[var_name]
             torch.cuda.empty_cache()
+
         logger.info(
             f"Capture cudagraph success, batch_size <={self.max_batch_size} "
             f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."
@@ -189,64 +237,52 @@ def warmup(self, model):
     @torch.no_grad()
     def warmup_overlap(self, model):
         logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.")
-        for batch_size in range(self.max_batch_size, 0, -1):
+        # local import, only used for the type annotation below
+        from .basemodel import TpPartBaseModel
+
+        model: TpPartBaseModel = model
+
+        for batch_size in self.cuda_graph_batch_sizes[::-1]:
             decode_batches = []
             for micro_batch_index in [0, 1]:
-                # dummy prefill
-                prefill_input_len = 1
-                dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+                # dummy decoding, capture the cudagraph
+                seq_len = 2
+                total_token_num = batch_size * seq_len
+                max_len_in_batch = self.graph_max_len_in_batch
+                input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
+                mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
                 b_req_idx = torch.tensor(
-                    [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-                )
-                mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-                b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-                total_token_num = prefill_input_len * batch_size
-                logics = model.forward(
-                    batch_size,
-                    total_token_num,
-                    prefill_input_len,
-                    dummy_input_ids,
-                    mem_indexes,
-                    b_req_idx,
-                    b_seq_len,
-                    b_ready_cache_len=b_ready_cache_len,
-                    is_prefill=True,
-                    multimodal_params=[],
+                    [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
                 )
-                mem_indexes = None
-                prob_out = torch.softmax(logics, dim=-1)
-                logics = None
-                predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-                prob_out = None
-                predict_ids = predict_ids.detach().cpu().numpy()
-                torch.cuda.empty_cache()
-
-                # dummy decoding, capture the cudagraph
-                total_token_num += batch_size
-                b_seq_len += 1
-                mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
+                b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
+                b_seq_len.fill_(seq_len)

-                micro_batch = DecodeMicroBatch(
+                micro_batch = ModelInput(
+                    is_prefill=False,
                     batch_size=batch_size,
                     total_token_num=total_token_num,
-                    max_len_in_batch=prefill_input_len + 1,
-                    input_ids=torch.from_numpy(predict_ids).cuda().reshape(-1),
+                    max_len_in_batch=max_len_in_batch,
+                    input_ids=input_ids,
                     mem_indexes=mem_indexes,
                     b_req_idx=b_req_idx,
                     b_seq_len=b_seq_len,
+                    **model._gen_special_model_input(batch_size),
                 )
                 decode_batches.append(micro_batch)
+                del micro_batch

                 for var_name, var_value in list(locals().items()):
                     if isinstance(var_value, torch.Tensor):
                         del locals()[var_name]
                 torch.cuda.empty_cache()
+
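+            # running the two dummy micro batches triggers capture of the overlapped cuda graph for this batch size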
             _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1])

             model.mem_manager.free_all()
             model.req_manager.free_all()

+            del decode_batches
+
             # release local tensors
             for var_name, var_value in list(locals().items()):
                 if isinstance(var_value, torch.Tensor):