@@ -89,9 +89,11 @@ def split_kwargs_n(
     b_ready_cache_len: torch.Tensor = None,
     multimodal_params=None,
     is_prefill=True,
-    split_n=2):
+    split_n=2,
+    run_reqs=None):
 
     kwargs = [None] * split_n
+    run_reqs_list = [None] * split_n
 
     # Compute the batch size of each split
     batch_per_split = [batch_size // split_n] * split_n
@@ -124,11 +126,13 @@ def split_kwargs_n(
             token_start = cumulative_tokens
             token_end = token_start + split_tokens
             split_input_ids = input_ids[token_start:token_end]
+            reqs = run_reqs[token_start:token_end]
             split_mem_indexes = mem_indexes[token_start:token_end]
         else:
             # In the decode stage, split by batch index
             split_input_ids = input_ids[start_idx:end_idx]
             split_mem_indexes = mem_indexes[start_idx:end_idx]
+            reqs = run_reqs[start_idx:end_idx]
 
         # Compute the other parameters for this split
         split_max_len = split_b_seq_len.max().item() if len(split_b_seq_len) > 0 else 0
@@ -139,6 +143,7 @@ def split_kwargs_n(
         if b_ready_cache_len is not None:
             split_b_ready_cache_len = b_ready_cache_len[start_idx:end_idx]
 
+        run_reqs_list[i] = reqs
         # Create the kwargs dict
         kwargs[i] = {
             "batch_size": len(split_b_req_idx),
@@ -161,7 +166,7 @@ def split_kwargs_n(
         # Update the cumulative token count
         cumulative_tokens += split_tokens
 
-    return kwargs
+    return kwargs, run_reqs_list
 
 class ContinuesBatchBackend(ModeBackend):
     def __init__(self) -> None:
@@ -175,7 +180,10 @@ def prefill(self, reqs: List[Tuple], stream_id):
         with torch.cuda.stream(self.model.stream[stream_id]):
             logits = self.model.forward(**kwargs)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            torch.cuda.current_stream().synchronize()
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+            self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+            # torch.cuda.current_stream().synchronize()
 
         # logits = self.model.forward(**kwargs)
 
@@ -193,10 +201,10 @@ def prefill(self, reqs: List[Tuple], stream_id):
         # logits = self.model.forward(**kwargs)
 
         # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        # next_token_ids = next_token_ids.detach().cpu().numpy()
+        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
 
-        return self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+        # self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
         # return
 
     def decode(self, stream_id):
@@ -207,13 +215,15 @@ def decode(self, stream_id):
         with torch.cuda.stream(self.model.stream[stream_id]):
             logits = self.model.forward(**kwargs)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            torch.cuda.current_stream().synchronize()
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+            self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
 
         # logits = self.model.forward(**kwargs)
 
         # split_n = self.model.stream_num
         # if kwargs["batch_size"] > split_n - 1:
-        #     kwargs_list = split_kwargs_n(**kwargs, split_n=split_n)
+        #     kwargs_list, run_reqs_list = split_kwargs_n(**kwargs, split_n=split_n, run_reqs=run_reqs)
         #     logits = [None] * split_n
         #     for i in range(split_n):
         #         with torch.cuda.stream(self.model.stream[i]):
@@ -225,11 +235,10 @@ def decode(self, stream_id):
         # logits = self.model.forward(**kwargs)
 
         # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-        self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
-        return stream_id
+        # next_token_ids = next_token_ids.detach().cpu().numpy()
+        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        # self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+        return
 
     def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs, stream_id):
         finished_req_ids = []
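
For reference, a minimal sketch of how the commented-out multi-stream branch in `decode` could consume the new `(kwargs, run_reqs_list)` return value: launch one forward per stream first so the kernels can overlap, then sample and post-process each split on its own stream. This is an illustration built on the names visible in the diff (`self.model`, `self.eos_id`, `self.post_handel`), not the finished implementation:

```python
# Sketch only: per-split execution inside decode(), assuming the
# backend attributes shown in the diff above.
split_n = self.model.stream_num
kwargs_list, run_reqs_list = split_kwargs_n(**kwargs, split_n=split_n, run_reqs=run_reqs)
logits = [None] * split_n
for i in range(split_n):
    with torch.cuda.stream(self.model.stream[i]):
        # Queue every forward before any host-side sync so streams overlap.
        logits[i] = self.model.forward(**kwargs_list[i])
for i in range(split_n):
    with torch.cuda.stream(self.model.stream[i]):
        next_token_ids, next_token_probs = sample(logits[i], run_reqs_list[i], self.eos_id)
        next_token_ids = next_token_ids.detach().cpu().numpy()  # blocks on stream i only
        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
        self.post_handel(run_reqs_list[i], next_token_ids, next_token_logprobs, i)
```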
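A note on the dropped `torch.cuda.current_stream().synchronize()`: with PyTorch's default blocking copy, `.cpu()` waits for the work already queued on the current stream before returning, so the numpy conversion itself acts as the synchronization point. A small standalone illustration (assumes a CUDA device is available):

```python
import torch

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    x = torch.randn(4, 8, device="cuda")
    logits = x @ x.T     # kernel queued on `stream`
    host = logits.cpu()  # default blocking copy: waits for `stream` to finish
# No explicit stream.synchronize() is needed before reading `host`.
print(host.sum())
```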