Skip to content

Commit c686225

Browse files
authored
fix
1 parent a000ab6 commit c686225

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ def forward(
229229
b_seq_len: torch.Tensor,
230230
b_ready_cache_len: torch.Tensor = None,
231231
multimodal_params=None,
232-
dist_group: CustomProcessGroup = None,
233232
is_prefill=True,
234233
):
235234
assert mem_indexes.is_cuda
@@ -246,7 +245,6 @@ def forward(
246245
b_seq_len,
247246
b_ready_cache_len,
248247
multimodal_params,
249-
dist_group,
250248
)
251249
else:
252250
return self._decode(
@@ -259,7 +257,6 @@ def forward(
259257
b_start_loc,
260258
b_seq_len,
261259
multimodal_params,
262-
dist_group,
263260
)
264261

265262
def _prefill(
@@ -274,7 +271,6 @@ def _prefill(
274271
b_seq_len,
275272
b_ready_cache_len,
276273
multimodal_params,
277-
dist_group: CustomProcessGroup = None,
278274
):
279275
infer_state = self.infer_state_class()
280276
infer_state.is_prefill = True
@@ -304,7 +300,7 @@ def _prefill(
304300
dtype=self.data_type,
305301
device="cuda",
306302
)
307-
infer_state.dist_group = dist_group if dist_group is not None else dist_group_manager.get_default_group()
303+
infer_state.dist_group = dist_group_manager.get_default_group()
308304

309305
init_req_to_token_indexes(
310306
self.req_manager.req_to_token_indexs,
@@ -330,7 +326,6 @@ def _decode(
330326
b_start_loc,
331327
b_seq_len,
332328
multimodal_params,
333-
dist_group: CustomProcessGroup = None,
334329
):
335330
infer_state = self.infer_state_class()
336331
infer_state.is_prefill = False
@@ -356,7 +351,7 @@ def _decode(
356351
dtype=self.data_type,
357352
device="cuda",
358353
)
359-
infer_state.dist_group = dist_group if dist_group is not None else dist_group_manager.get_default_group()
354+
infer_state.dist_group = dist_group_manager.get_default_group()
360355
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
361356

362357
infer_state.init_some_extra_state(self, input_ids)

0 commit comments

Comments
 (0)