Skip to content

Commit 8b5559a

Browse files
committed
fix microbatch_index.
1 parent a9f283c commit 8b5559a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def forward(self, model_input: ModelInput):
241241
else:
242242
return self._decode(model_input)
243243

244-
def _create_inferstate(self, model_input: ModelInput, batch_index: int = 0):
244+
def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
245245
infer_state = self.infer_state_class()
246246
infer_state.is_prefill = model_input.is_prefill
247247
infer_state.is_token_healing = self.is_token_healing
@@ -269,7 +269,8 @@ def _create_inferstate(self, model_input: ModelInput, batch_index: int = 0):
269269
(model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
270270
self.data_type,
271271
)
272-
infer_state.dist_group = dist_group_manager.get_group(batch_index)
272+
infer_state.microbatch_index = microbatch_index
273+
infer_state.dist_group = dist_group_manager.get_group(microbatch_index)
273274

274275
# 特殊模型,特殊模式的特定变量初始化操作。
275276
infer_state.deepseekv3_mtp_draft_input_hiddens = model_input.deepseekv3_mtp_draft_input_hiddens

0 commit comments

Comments
 (0)