Skip to content

Commit d9924cc

Browse files
Fix moe loss and sp (#5316)
1 parent df7535d commit d9924cc

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
from swift.llm import get_llm_model
1010
from .base import CommonSequenceParallel
11-
from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_and_entropies_grpo,
12-
_get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
13-
get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
14-
split_by_mini_batches_grpo)
11+
from .utils import (GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
12+
_get_per_token_logps_and_entropies_grpo, _get_train_sampler_grpo, _prepare_inputs,
13+
_prepare_inputs_grpo, get_common_dataloader, get_per_token_logps, loss_scale_sp_func,
14+
old_policy_grpo, setup_compute_acc, split_by_mini_batches_grpo)
1515

1616
assert version.parse(torch.__version__) >= version.parse('2.0.0')
1717
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -233,6 +233,30 @@ def pre_forward_split_hook(_self, args, kwargs):
233233
if hasattr(base_model, '_update_causal_mask'):
234234
self.causal_mask_func = base_model._update_causal_mask
235235
base_model.register_forward_pre_hook(pre_forward_split_hook, with_kwargs=True)
236+
base_model: torch.nn.Module
237+
238+
def moe_aux_loss_hook(module, args, kwargs, output):
239+
router_logits = getattr(output, 'router_logits', None)
240+
if router_logits is None:
241+
return output
242+
243+
attention_mask = kwargs['attention_mask']
244+
num_layers = len(router_logits)
245+
sp_len = router_logits[0].shape[0]
246+
if isinstance(router_logits, tuple):
247+
compute_device = router_logits[0].device
248+
router_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in router_logits], dim=0)
249+
router_logits, _ = GatherLoss.apply(router_logits, None, self.sp_group)
250+
router_logits = router_logits.reshape(self.sp_world_size, num_layers, sp_len,
251+
-1).transpose(0, 1).reshape(num_layers, self.sp_world_size * sp_len,
252+
-1)
253+
if attention_mask is not None:
254+
router_logits = router_logits[:, :attention_mask.shape[1], :]
255+
output['router_logits'] = tuple([logit.squeeze() for logit in router_logits.split(1, dim=0)])
256+
return output
257+
258+
if model.model_info.is_moe_model:
259+
base_model.register_forward_hook(moe_aux_loss_hook, with_kwargs=True)
236260
self.model_dtype = next(model.parameters()).dtype
237261
self.tokenizer = tokenizer
238262

swift/trainers/sequence_parallel/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from contextlib import contextmanager
55
from functools import partial
6-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple
6+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
77

88
import datasets
99
import numpy as np
@@ -50,19 +50,25 @@ def forward(ctx, loss, labels, process_group, gather_idx=None):
5050
gather_idx: gather the tensors on this dim
5151
"""
5252
ctx.process_group = process_group
53-
shape0 = labels.shape[0]
54-
ctx.scatter_shape = labels.shape[gather_idx or 0]
53+
# change from label.shape to loss, because label may be None
54+
shape0 = loss.shape[0]
55+
ctx.scatter_shape = loss.shape[gather_idx or 0]
5556
ctx.gather_idx = gather_idx or 0
5657
world_size = dist.get_world_size(group=process_group) # the sp world size
5758
output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device)
5859
# gather all from sp group
5960
dist.all_gather_into_tensor(output, loss, group=process_group)
6061
if gather_idx is not None:
6162
output = torch.cat(output.split(shape0, dim=0), dim=gather_idx)
62-
labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device)
63-
dist.all_gather_into_tensor(labels_output, labels, group=process_group)
64-
if gather_idx is not None:
65-
labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
63+
if labels is not None:
64+
labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]),
65+
dtype=labels.dtype,
66+
device=labels.device)
67+
dist.all_gather_into_tensor(labels_output, labels, group=process_group)
68+
if gather_idx is not None:
69+
labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
70+
else:
71+
labels_output = None
6672
return output, labels_output
6773

6874
@staticmethod

swift/trainers/trainers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
395395
else:
396396
loss = self.label_smoother(outputs, labels)
397397

398+
if self.model.model_info.is_moe_model and self.args.router_aux_loss_coef is not None:
399+
aux_loss = outputs.get('aux_loss')
400+
if aux_loss is not None:
401+
loss = loss + self.args.router_aux_loss_coef * aux_loss.to(loss.device)
402+
398403
if self.template.sequence_parallel_size > 1:
399404
from swift.trainers.sequence_parallel import sequence_parallel
400405
loss = sequence_parallel.reduce_outputs(loss, labels)

0 commit comments

Comments
 (0)