diff --git a/swift/trainers/sequence_parallel/ulysses.py b/swift/trainers/sequence_parallel/ulysses.py
index 837d033481..7c0c085d08 100644
--- a/swift/trainers/sequence_parallel/ulysses.py
+++ b/swift/trainers/sequence_parallel/ulysses.py
@@ -8,10 +8,10 @@

 from swift.llm import get_llm_model
 from .base import CommonSequenceParallel
-from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_and_entropies_grpo,
-                    _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
-                    get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
-                    split_by_mini_batches_grpo)
+from .utils import (GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
+                    _get_per_token_logps_and_entropies_grpo, _get_train_sampler_grpo, _prepare_inputs,
+                    _prepare_inputs_grpo, get_common_dataloader, get_per_token_logps, loss_scale_sp_func,
+                    old_policy_grpo, setup_compute_acc, split_by_mini_batches_grpo)

 assert version.parse(torch.__version__) >= version.parse('2.0.0')
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -233,6 +233,30 @@ def pre_forward_split_hook(_self, args, kwargs):
         if hasattr(base_model, '_update_causal_mask'):
             self.causal_mask_func = base_model._update_causal_mask
         base_model.register_forward_pre_hook(pre_forward_split_hook, with_kwargs=True)
+        base_model: torch.nn.Module
+
+        def moe_aux_loss_hook(module, args, kwargs, output):
+            router_logits = getattr(output, 'router_logits', None)
+            if router_logits is None:
+                return output
+
+            attention_mask = kwargs['attention_mask']
+            num_layers = len(router_logits)
+            sp_len = router_logits[0].shape[0]
+            if isinstance(router_logits, tuple):
+                compute_device = router_logits[0].device
+                router_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in router_logits], dim=0)
+            router_logits, _ = GatherLoss.apply(router_logits, None, self.sp_group)
+            router_logits = router_logits.reshape(self.sp_world_size, num_layers, sp_len,
+                                                  -1).transpose(0, 1).reshape(num_layers, self.sp_world_size * sp_len,
+                                                                              -1)
+            if attention_mask is not None:
+                router_logits = router_logits[:, :attention_mask.shape[1], :]
+            output['router_logits'] = tuple([logit.squeeze() for logit in router_logits.split(1, dim=0)])
+            return output
+
+        if model.model_info.is_moe_model:
+            base_model.register_forward_hook(moe_aux_loss_hook, with_kwargs=True)
         self.model_dtype = next(model.parameters()).dtype
         self.tokenizer = tokenizer

diff --git a/swift/trainers/sequence_parallel/utils.py b/swift/trainers/sequence_parallel/utils.py
index 96b20f1bfd..b8bf8d68e7 100644
--- a/swift/trainers/sequence_parallel/utils.py
+++ b/swift/trainers/sequence_parallel/utils.py
@@ -3,7 +3,7 @@
 import os
 from contextlib import contextmanager
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union

 import datasets
 import numpy as np
@@ -50,8 +50,9 @@ def forward(ctx, loss, labels, process_group, gather_idx=None):
             gather_idx: gather the tensors on this dim
         """
         ctx.process_group = process_group
-        shape0 = labels.shape[0]
-        ctx.scatter_shape = labels.shape[gather_idx or 0]
+        # change from label.shape to loss, because label may be None
+        shape0 = loss.shape[0]
+        ctx.scatter_shape = loss.shape[gather_idx or 0]
         ctx.gather_idx = gather_idx or 0
         world_size = dist.get_world_size(group=process_group)  # the sp world size
         output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device)
@@ -59,10 +60,15 @@
         dist.all_gather_into_tensor(output, loss, group=process_group)
         if gather_idx is not None:
             output = torch.cat(output.split(shape0, dim=0), dim=gather_idx)
-        labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device)
-        dist.all_gather_into_tensor(labels_output, labels, group=process_group)
-        if gather_idx is not None:
-            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
+        if labels is not None:
+            labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]),
+                                        dtype=labels.dtype,
+                                        device=labels.device)
+            dist.all_gather_into_tensor(labels_output, labels, group=process_group)
+            if gather_idx is not None:
+                labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
+        else:
+            labels_output = None
         return output, labels_output

     @staticmethod
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 9f23c0752e..b49c43b6d8 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -398,6 +398,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             else:
                 loss = self.label_smoother(outputs, labels)

+        if self.model.model_info.is_moe_model and self.args.router_aux_loss_coef is not None:
+            aux_loss = outputs.get('aux_loss')
+            if aux_loss is not None:
+                loss = loss + self.args.router_aux_loss_coef * aux_loss.to(loss.device)
+
         if self.template.sequence_parallel_size > 1:
             from swift.trainers.sequence_parallel import sequence_parallel
             loss = sequence_parallel.reduce_outputs(loss, labels)