@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from functools import partial
+from typing import Optional
 
 import megatron.core
 import torch
@@ -21,25 +22,18 @@
 class MegatronTrainer(BaseMegatronTrainer):
 
     # Code borrowed from NVIDIA/Megatron-LM
-    def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
-        """Loss function.
-
-        Args:
-            output_tensor (torch.Tensor): The tensor with the losses
-            loss_mask (torch.Tensor): Used to mask out some portions of the loss
-
-        Returns:
-            the loss scalar for this micro-batch
-            the number of non-padded tokens in this microbatch
-            a dict containing reporting metrics on the loss and number of tokens across
-            the data parallel ranks
-        """
+    def loss_func(self,
+                  output_tensor: torch.Tensor,
+                  *,
+                  labels: torch.Tensor,
+                  loss_scale: Optional[torch.Tensor] = None):
         args = get_args()
 
         losses = output_tensor.float()
-        loss_mask = loss_mask.view(-1).float()
-        total_tokens = loss_mask.sum()
-        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+        if loss_scale is not None:
+            losses = losses * loss_scale
+        loss_mask = labels != -100
+        loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)])
 
         megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
         if args.context_parallel_size > 1 and not megatron_core_013:
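For reference, here is a minimal standalone sketch of the reduction the new `loss_func` performs (the `masked_loss` name, tensor shapes, and example values are illustrative assumptions, not the trainer's API; `-100` is the usual ignore index for labels):

```python
from typing import Optional

import torch


def masked_loss(per_token_loss: torch.Tensor,
                labels: torch.Tensor,
                loss_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sketch of the new reduction: [masked loss sum, non-padded token count]."""
    losses = per_token_loss.float()
    # Optional per-token reweighting, applied before masking.
    if loss_scale is not None:
        losses = losses * loss_scale
    # Tokens labeled -100 are padding/prompt and contribute nothing.
    loss_mask = labels != -100
    # Packing the sum and the count into one tensor lets a single
    # all-reduce aggregate both across data-parallel ranks later on.
    return torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)])


# Two of four tokens are ignored; the packed result is tensor([2.5, 2.0]).
print(masked_loss(torch.tensor([0.5, 1.5, 2.0, 3.0]),
                  torch.tensor([7, -100, 9, -100])))
```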
@@ -109,9 +103,4 @@ def forward_step(self, data_iterator, model):
         with self.stimer:
             output_tensor = model(**data)
         labels = data.get('labels')
-        if loss_scale is None:
-            loss_mask = None if labels is None else (labels != -100).float()
-        else:
-            loss_scale[labels == -100] = 0
-            loss_mask = loss_scale
-        return output_tensor, partial(self.loss_func, loss_mask=loss_mask)
+        return output_tensor, partial(self.loss_func, labels=labels, loss_scale=loss_scale)
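The rewritten `forward_step` also drops the in-place mutation of `loss_scale` (`loss_scale[labels == -100] = 0` is gone); masking now happens inside `loss_func`, leaving the caller's tensor untouched. A hedged sketch of the deferred-loss pattern this relies on, reusing the hypothetical `masked_loss` above (`output_tensor` and `data` are stand-ins for the model output and micro-batch):

```python
from functools import partial

import torch

# Stand-ins for the model's per-token losses and the current micro-batch.
output_tensor = torch.tensor([0.5, 1.5, 2.0, 3.0])
data = {'labels': torch.tensor([7, -100, 9, 11]), 'loss_scale': None}

# forward_step returns the output plus a callable; Megatron's pipeline
# schedule later invokes the callable with output_tensor to get the loss.
loss_callable = partial(masked_loss,
                        labels=data['labels'],
                        loss_scale=data.get('loss_scale'))
print(loss_callable(output_tensor))  # tensor([5.5000, 3.0000])
```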