@@ -229,7 +229,6 @@ def forward(
229229 b_seq_len : torch .Tensor ,
230230 b_ready_cache_len : torch .Tensor = None ,
231231 multimodal_params = None ,
232- dist_group : CustomProcessGroup = None ,
233232 is_prefill = True ,
234233 ):
235234 assert mem_indexes .is_cuda
@@ -246,7 +245,6 @@ def forward(
246245 b_seq_len ,
247246 b_ready_cache_len ,
248247 multimodal_params ,
249- dist_group ,
250248 )
251249 else :
252250 return self ._decode (
@@ -259,7 +257,6 @@ def forward(
259257 b_start_loc ,
260258 b_seq_len ,
261259 multimodal_params ,
262- dist_group ,
263260 )
264261
265262 def _prefill (
@@ -274,7 +271,6 @@ def _prefill(
274271 b_seq_len ,
275272 b_ready_cache_len ,
276273 multimodal_params ,
277- dist_group : CustomProcessGroup = None ,
278274 ):
279275 infer_state = self .infer_state_class ()
280276 infer_state .is_prefill = True
@@ -304,7 +300,7 @@ def _prefill(
304300 dtype = self .data_type ,
305301 device = "cuda" ,
306302 )
307- infer_state .dist_group = dist_group if dist_group is not None else dist_group_manager .get_default_group ()
303+ infer_state .dist_group = dist_group_manager .get_default_group ()
308304
309305 init_req_to_token_indexes (
310306 self .req_manager .req_to_token_indexs ,
@@ -330,7 +326,6 @@ def _decode(
330326 b_start_loc ,
331327 b_seq_len ,
332328 multimodal_params ,
333- dist_group : CustomProcessGroup = None ,
334329 ):
335330 infer_state = self .infer_state_class ()
336331 infer_state .is_prefill = False
@@ -356,7 +351,7 @@ def _decode(
356351 dtype = self .data_type ,
357352 device = "cuda" ,
358353 )
359- infer_state .dist_group = dist_group if dist_group is not None else dist_group_manager .get_default_group ()
354+ infer_state .dist_group = dist_group_manager .get_default_group ()
360355 copy_kv_index_to_req (self .req_manager .req_to_token_indexs , b_req_idx , b_seq_len , infer_state .mem_index )
361356
362357 infer_state .init_some_extra_state (self , input_ids )
0 commit comments