Commit 8cad75e

fix overflow of efficient_attn_ratio (#1436)
1 parent 6a4f3f5 commit 8cad75e

2 files changed: 4 additions and 4 deletions


xtuner/v1/engine/train_engine.py

Lines changed: 2 additions & 2 deletions
@@ -281,8 +281,8 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
 step_consumed_tokens += seq_ctx.mask.sum()

 num_tokens = seq_ctx.cu_seq_lens_k[1:] - seq_ctx.cu_seq_lens_k[:-1]
-efficient_forward_tokens += (num_tokens**2).sum()
-total_forward_tokens += (num_tokens.sum()) ** 2
+efficient_forward_tokens += (num_tokens.long() ** 2).sum()
+total_forward_tokens += (num_tokens.long().sum()) ** 2

 if self.intra_layer_micro_batch == 1:
     output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])

xtuner/v1/engine/vision_compose_train_engine.py

Lines changed: 2 additions & 2 deletions
@@ -167,8 +167,8 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
 step_consumed_img_tokens /= seq_ctx.sequence_parallel_mesh.size()

 num_tokens = seq_ctx.cu_seq_lens_k[1:] - seq_ctx.cu_seq_lens_k[:-1]
-efficient_forward_tokens += (num_tokens**2).sum()
-total_forward_tokens += (num_tokens.sum()) ** 2
+efficient_forward_tokens += (num_tokens.long() ** 2).sum()
+total_forward_tokens += (num_tokens.long().sum()) ** 2

 # todo: support intra_layer_micro_batch
 output = self.model(seq_ctx=seq_ctx_list[0], loss_ctx=loss_ctx_list[0])
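
Note on the overflow: the sketch below illustrates why the cast to `.long()` matters. It is not taken from the repository. It assumes that `cu_seq_lens_k` is stored as int32 (which the added `.long()` cast suggests but the diff does not state) and that squaring an int32 value wraps around in two's complement. The helpers `wrap_int32` and `sum_of_squares`, and the sequence lengths, are hypothetical and written in plain Python so the arithmetic can be checked without a tensor library.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(x: int) -> int:
    """Emulate two's-complement int32 wraparound on a Python int."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x > INT32_MAX else x


def sum_of_squares(lengths: list[int], widen: bool) -> int:
    """Sum of squared sequence lengths.

    widen=False squares each length as if it were int32 (assumed pre-fix dtype);
    widen=True squares in a wider type, mirroring the `.long()` cast.
    """
    squares = [n * n if widen else wrap_int32(n * n) for n in lengths]
    return sum(squares)


# Hypothetical packed batch: cumulative offsets for two sequences
# of 30,000 and 50,000 tokens.
cu_seq_lens = [0, 30_000, 80_000]
lengths = [b - a for a, b in zip(cu_seq_lens, cu_seq_lens[1:])]

narrow = sum_of_squares(lengths, widen=False)
wide = sum_of_squares(lengths, widen=True)
total = sum(lengths) ** 2

print("squares evaluated as int32:", narrow)  # negative: 50_000**2 exceeds INT32_MAX
print("squares evaluated as int64:", wide)
print("ratio with widened squares:", wide / total)
```

Under these assumptions, any single sequence longer than 46,340 tokens has a square above 2^31 - 1, so long packed sequences are where the wrapped value would appear.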
