Commit 98a93ab

[bugfix] fix megatron loss_scale (#5406)
1 parent 8a527fe commit 98a93ab

4 files changed: +14 additions, -31 deletions

4 files changed

+14
-31
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 1 deletion
@@ -754,5 +754,5 @@ Besides the model-specific parameters of qwen2_5_vl and qwen2_audio, qwen2_5_omni also
 - NNODES: Pass-through of the `--nnodes` argument in torchrun.
 - NODE_RANK: Pass-through of the `--node_rank` argument in torchrun.
 - LOG_LEVEL: The log level, 'INFO' by default; it can be set to 'WARNING', 'ERROR', etc.
-- SWIFT_DEBUG: During `engine.infer(...)`, if set to '1', the contents of input_ids and generate_ids will be printed.
+- SWIFT_DEBUG: During `engine.infer(...)`, if set to '1', PtEngine will print the contents of input_ids and generate_ids.
 - VLLM_USE_V1: Used to switch between the V0 and V1 versions of vLLM.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -771,5 +771,5 @@ For the meaning of the arguments, please refer to [here](https://modelscope.cn/m
 - NNODES: Pass-through for the `--nnodes` parameter in torchrun.
 - NODE_RANK: Pass-through for the `--node_rank` parameter in torchrun.
 - LOG_LEVEL: The log level, default is 'INFO'. You can set it to 'WARNING', 'ERROR', etc.
-- SWIFT_DEBUG: During `engine.infer(...)`, if set to '1', the content of input_ids and generate_ids will be printed.
+- SWIFT_DEBUG: When set to '1', the PtEngine will print the contents of input_ids and generate_ids during `engine.infer(...)`.
 - VLLM_USE_V1: Used to switch between V0 and V1 versions of vLLM.
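For context, a minimal sketch of how the SWIFT_DEBUG flag described in the two doc changes above would typically be enabled. Only the environment-variable line is taken from the docs; the commented-out PtEngine construction and `infer()` call are assumptions about the ms-swift API, not part of this commit.

```python
import os

# Enable debug printing of input_ids / generate_ids inside PtEngine (see the doc change above).
# The variable must be set before inference runs.
os.environ['SWIFT_DEBUG'] = '1'

# Hypothetical usage; the exact PtEngine constructor and infer() signature are assumptions.
# from swift.llm import PtEngine, InferRequest, RequestConfig
# engine = PtEngine('Qwen/Qwen2.5-7B-Instruct')
# responses = engine.infer([InferRequest(messages=[{'role': 'user', 'content': 'Hello'}])],
#                          RequestConfig(max_tokens=64))
```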

swift/megatron/init.py

Lines changed: 1 addition & 7 deletions
@@ -439,7 +439,6 @@ def forward(
         output, bias = self.linear_proj(core_attn_out)

         return output, bias
-        pass

     MultiLatentAttention.forward = forward
@@ -555,12 +554,7 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po
             sequence_start = inference_context.sequence_len_offset
             sequence_end = sequence_start + q_len
             rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
-        else:
-            # Shorten rotary_pos_emb to the sequence length when inference_params
-            # is not provided. This makes sure we can run forward directly with
-            # any sequence length. During training, the sequence length is always
-            # the full rotary_pos_emb length.
-            rotary_pos_emb = rotary_pos_emb[0:q_len]
+        # Remove the else branch to fix cp.

         # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim]
         k_pos_emb = torch.unsqueeze(k_pos_emb, -2)
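The dropped else branch always took the first q_len rotary positions, which is presumably what broke context parallelism (cp): each cp rank only holds a slice of the sequence, so it needs the rotary embedding for its own (generally non-contiguous) positions rather than positions 0..q_len. A minimal illustration of that mismatch; the seq_len/cp_size values and the load-balanced chunk layout are assumptions for the example, not code from this commit.

```python
import torch

# Illustrative values only.
seq_len, cp_size = 8, 2
rotary_pos_emb = torch.arange(seq_len)   # stand-in for one embedding per position
q_len = seq_len // cp_size               # local sequence length on one cp rank

# Megatron-style load-balanced cp assigns each rank two non-contiguous chunks,
# e.g. with cp_size=2 rank 1 holds positions [2, 3] and [4, 5] out of 0..7.
rank1_positions = torch.tensor([2, 3, 4, 5])

print(rotary_pos_emb[0:q_len])           # tensor([0, 1, 2, 3])  <- what the removed branch would keep
print(rotary_pos_emb[rank1_positions])   # tensor([2, 3, 4, 5])  <- positions rank 1 actually needs
```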

swift/megatron/trainers/trainer.py

Lines changed: 11 additions & 22 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from functools import partial
+from typing import Optional

 import megatron.core
 import torch
@@ -21,25 +22,18 @@
 class MegatronTrainer(BaseMegatronTrainer):

     # Code borrowed from NVIDIA/Megatron-LM
-    def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
-        """Loss function.
-
-        Args:
-            output_tensor (torch.Tensor): The tensor with the losses
-            loss_mask (torch.Tensor): Used to mask out some portions of the loss
-
-        Returns:
-            the loss scalar for this micro-batch
-            the number of non-padded tokens in this microbatch
-            a dict containing reporting metrics on the loss and number of tokens across
-            the data parallel ranks
-        """
+    def loss_func(self,
+                  output_tensor: torch.Tensor,
+                  *,
+                  labels: torch.Tensor,
+                  loss_scale: Optional[torch.Tensor] = None):
         args = get_args()

         losses = output_tensor.float()
-        loss_mask = loss_mask.view(-1).float()
-        total_tokens = loss_mask.sum()
-        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+        if loss_scale is not None:
+            losses = losses * loss_scale
+        loss_mask = labels != -100
+        loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)])

         megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
         if args.context_parallel_size > 1 and not megatron_core_013:
@@ -109,9 +103,4 @@ def forward_step(self, data_iterator, model):
         with self.stimer:
             output_tensor = model(**data)
         labels = data.get('labels')
-        if loss_scale is None:
-            loss_mask = None if labels is None else (labels != -100).float()
-        else:
-            loss_scale[labels == -100] = 0
-            loss_mask = loss_scale
-        return output_tensor, partial(self.loss_func, loss_mask=loss_mask)
+        return output_tensor, partial(self.loss_func, labels=labels, loss_scale=loss_scale)
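The trainer change moves the masking into loss_func itself: the mask is rebuilt from labels != -100 and an optional per-token loss_scale multiplies the losses before reduction, instead of being folded into a precomputed float loss_mask. A standalone sketch of just that arithmetic, with the get_args()/context-parallel reduction parts omitted and made-up tensor values:

```python
from typing import Optional

import torch


def loss_func_sketch(output_tensor: torch.Tensor,
                     labels: torch.Tensor,
                     loss_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirrors the core of the patched MegatronTrainer.loss_func above,
    # without the Megatron args / context-parallel reduction logic.
    losses = output_tensor.float()
    if loss_scale is not None:
        losses = losses * loss_scale      # optional per-token scaling
    loss_mask = labels != -100            # mask derived from labels, not passed in separately
    # [sum of masked (scaled) per-token losses, number of unmasked tokens]
    return torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)])


# Toy example: one ignored token (label -100) and per-token scales of 1.0 / 0.5 / 2.0.
per_token_loss = torch.tensor([[1.0, 2.0, 3.0]])
labels = torch.tensor([[-100, 5, 7]])
scale = torch.tensor([[1.0, 0.5, 2.0]])
print(loss_func_sketch(per_token_loss, labels, scale))  # masked sum 7.0, token count 2
```

In the patched forward_step, the two inputs are then bound with `partial(self.loss_func, labels=labels, loss_scale=loss_scale)`, so loss_scale no longer has to be mutated into a loss_mask before the loss function runs.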
