@@ -23,7 +23,9 @@ def split_kwargs(
2323 b_seq_len : torch .Tensor ,
2424 b_ready_cache_len : torch .Tensor = None ,
2525 multimodal_params = None ,
26- is_prefill = True ):
26+ is_prefill = True ,
27+ split_n = 2 ):
28+ kwargs = [None ] * split_n
2729 half_batch = batch_size // 2
2830 b_req_idx1 = b_req_idx [:half_batch ]
2931 b_req_idx2 = b_req_idx [half_batch :]
@@ -77,6 +79,92 @@ def split_kwargs(
7779 kwargs2 ["multimodal_params" ] = multimodal_params
7880 return kwargs1 , kwargs2
7981
def split_kwargs_n(
    batch_size,
    total_token_num,
    max_len_in_batch,
    input_ids: torch.Tensor,
    mem_indexes: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_start_loc: torch.Tensor,
    b_seq_len: torch.Tensor,
    b_ready_cache_len: torch.Tensor = None,
    multimodal_params=None,
    is_prefill=True,
    split_n=2,
):
    """Split one batched set of forward kwargs into ``split_n`` shards.

    Requests are divided as evenly as possible: the first
    ``batch_size % split_n`` shards receive one extra request.  Batch-indexed
    tensors (``b_req_idx``, ``b_seq_len``, ``b_ready_cache_len``) are sliced
    by request index; token-indexed tensors (``input_ids``, ``mem_indexes``)
    are sliced by the running token offset in prefill mode and by request
    index in decode mode (one token per request).

    Returns a list of ``split_n`` kwargs dicts, each tagged with an
    ``all_reduce_id`` equal to its shard index.

    NOTE(review): in prefill the per-shard token count is taken from
    ``b_seq_len``; if ``input_ids`` holds only non-cached tokens when
    ``b_ready_cache_len`` > 0, confirm upstream that the lengths align.
    NOTE(review): the same ``multimodal_params`` object is shared by every
    shard, matching the prior two-way split behavior.
    """
    base, extra = divmod(batch_size, split_n)
    # Shard sizes: the first `extra` shards each absorb one leftover request.
    shard_sizes = [base + (1 if i < extra else 0) for i in range(split_n)]

    shards = []
    req_start = 0  # first request index of the current shard
    tok_start = 0  # first token index of the current shard (prefill slicing)
    for shard_id, size in enumerate(shard_sizes):
        req_end = req_start + size

        shard_req_idx = b_req_idx[req_start:req_end]
        shard_seq_len = b_seq_len[req_start:req_end]
        n_tokens = shard_seq_len.sum().item()

        if is_prefill:
            # Prefill: tokens are packed per request, so slice by token offset.
            shard_input_ids = input_ids[tok_start:tok_start + n_tokens]
            shard_mem_indexes = mem_indexes[tok_start:tok_start + n_tokens]
        else:
            # Decode: exactly one token per request, so batch slicing applies.
            shard_input_ids = input_ids[req_start:req_end]
            shard_mem_indexes = mem_indexes[req_start:req_end]

        shard = {
            "batch_size": len(shard_req_idx),
            "total_token_num": n_tokens,
            "max_len_in_batch": shard_seq_len.max().item() if len(shard_seq_len) > 0 else 0,
            "input_ids": shard_input_ids,
            "mem_indexes": shard_mem_indexes,
            "b_req_idx": shard_req_idx,
            # Start offsets recomputed per shard: exclusive prefix sum of the
            # shard's own sequence lengths.
            "b_start_loc": shard_seq_len.cumsum(dim=0) - shard_seq_len,
            "b_seq_len": shard_seq_len,
            "b_ready_cache_len": (
                b_ready_cache_len[req_start:req_end] if b_ready_cache_len is not None else None
            ),
            "is_prefill": is_prefill,
            "all_reduce_id": shard_id,
        }
        if multimodal_params is not None:
            shard["multimodal_params"] = multimodal_params
        shards.append(shard)

        req_start = req_end
        tok_start += n_tokens

    return shards
167+
80168class ContinuesBatchBackend (ModeBackend ):
81169 def __init__ (self ) -> None :
82170 super ().__init__ ()
@@ -106,14 +194,16 @@ def prefill(self, reqs: List[Tuple]):
106194 def decode (self ):
107195 kwargs , run_reqs = prepare_decode_inputs (g_infer_context .infer_req_ids )
108196 # logits = self.model.forward(**kwargs)
109- if kwargs ["batch_size" ] > 1 :
110- kwargs1 , kwargs2 = split_kwargs (** kwargs )
111- with torch .cuda .stream (self .model .stream [0 ]):
112- logits1 = self .model .forward (** kwargs1 )
113- with torch .cuda .stream (self .model .stream [1 ]):
114- logits2 = self .model .forward (** kwargs2 )
197+ split_n = self .model .stream_num
198+ if kwargs ["batch_size" ] > split_n - 1 :
199+ kwargs_list = split_kwargs_n (** kwargs , split_n = split_n )
200+ logits = [None ] * split_n
201+ for i in range (split_n ):
202+ with torch .cuda .stream (self .model .stream [i ]):
203+ logits [i ] = self .model .forward (** kwargs_list [i ])
204+
115205 torch .cuda .synchronize ()
116- logits = torch .cat ([ logits1 , logits2 ] , dim = 0 )
206+ logits = torch .cat (logits , dim = 0 )
117207 else :
118208 logits = self .model .forward (** kwargs )
119209
0 commit comments