diff --git "a/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" "b/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" index 3e256a35b8..5d235a6132 100644 --- "a/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" +++ "b/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" @@ -32,10 +32,11 @@ example在[这里](https://github.com/modelscope/ms-swift/blob/main/swift/plugin SWIFT支持在plugin中定制loss。如果不使用这个能力,默认会使用交叉熵Loss(CE Loss)。开发者可以在这个文件中编写代码,注册后trainer会自动使用你定制的loss方法。 例如在plugin/loss.py中添加下面的代码: ```python -@register_loss_func("custom_loss") -def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: +def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: # Write your own loss calculating here return loss + +loss_mapping['custom_loss'] = custom_loss_func ``` 需要注意的是,loss和trainer训练的任务是强相关的,目前的loss定制针对pt和sft任务,如果是人类对齐任务(例如DPO、PPO等)或分类任务(seq_cls)任务在插件中是无法定制的。 diff --git a/docs/source_en/Customization/Pluginization.md b/docs/source_en/Customization/Pluginization.md index ff0c99af52..7fad067ffe 100644 --- a/docs/source_en/Customization/Pluginization.md +++ b/docs/source_en/Customization/Pluginization.md @@ -37,10 +37,11 @@ SWIFT supports customizing the loss function through plugins. If this feature is For example, adding the following code in `plugin/loss.py`: ```python -@register_loss_func("custom_loss") -def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: +def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: # Write your own loss calculation here return loss + +loss_mapping['custom_loss'] = custom_loss_func ``` It is important to note that the loss function is strongly related to the training task. Currently, loss customization supports PT and SFT tasks. For human alignment tasks (e.g., DPO, PPO) or classification tasks (seq_cls), loss customization through plugins is not supported. 
diff --git a/requirements/framework.txt b/requirements/framework.txt index d147e5080d..fef91f96fa 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -19,7 +19,7 @@ numpy openai oss2 pandas -peft>=0.11,<0.17 +peft>=0.11,<0.18 pillow PyYAML>=5.4 requests diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 7073c3cb60..13fa40eb0b 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -6,7 +6,7 @@ from transformers import Seq2SeqTrainingArguments from transformers.utils.versions import require_version -from swift.plugin import LOSS_MAPPING +from swift.plugin import loss_mapping from swift.trainers import TrainerFactory from swift.trainers.arguments import TrainArgumentsMixin from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master, @@ -118,7 +118,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra create_checkpoint_symlink: bool = False # plugin - loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'}) + loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'}) metric: Optional[str] = None # extra diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index b89773bffc..d495506952 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -971,10 +971,6 @@ def _encode_context_list( tokenizer_kwargs = {} if loss_scale_list is None: loss_scale_list = [0.] * len(context_list) - if self.loss_scale.keep_loss_scale: - ignore_loss_scale = False - else: - ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list) for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)): if isinstance(context, str): # tokenizer_kwargs is the returned tokenizer_kwargs, @@ -987,9 +983,9 @@ def _encode_context_list( labels += token_list else: labels += [-100] * len(token_list) - if not ignore_loss_scale: + if not self.loss_scale.is_binary: loss_scale.extend([loss_weight] * len(token_list)) - if ignore_loss_scale: + if self.loss_scale.is_binary: loss_scale = None return input_ids, labels, loss_scale, tokenizer_kwargs diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index e89c0e7a4e..a2eb2ac1e1 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -182,6 +182,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta.convert_hf2mcore(hf_model, mg_model) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template) + del hf_model logger.info('Successfully transferred HF model weights to MG model.') args.save_args() mg_save_checkpoint(1, [mg_model], None, None, 0) @@ -228,6 +229,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta.convert_mcore2hf(hf_model, mg_model) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template) + del mg_model logger.info('Successfully transferred MG model weights to HF model.') ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load save_checkpoint( diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 4dc40bcf58..7503784b10 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from .callback import extra_callbacks - from .loss import LOSS_MAPPING, get_loss_func + 
from .loss import loss_mapping, get_loss_func from .loss_scale import loss_scale_map from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu from .optimizer import optimizers_map @@ -21,7 +21,7 @@ else: _import_structure = { 'callback': ['extra_callbacks'], - 'loss': ['LOSS_MAPPING', 'get_loss_func'], + 'loss': ['loss_mapping', 'get_loss_func'], 'loss_scale': ['loss_scale_map'], 'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'], 'optimizer': ['optimizers_map'], diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 456e78b258..6926a64df5 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -15,78 +15,18 @@ from swift.plugin import MeanMetric -class LossType: - loss_scale = 'loss_scale' - cosine_similarity = 'cosine_similarity' - contrastive = 'contrastive' - online_contrastive = 'online_contrastive' - infonce = 'infonce' - channel_loss = 'channel_loss' - reranker = 'reranker' - generative_reranker = 'generative_reranker' - listwise_reranker = 'listwise_reranker' - listwise_generative_reranker = 'listwise_generative_reranker' - - -LOSS_MAPPING = {} - - -def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None): - loss_info = {} - - if loss_func is not None: - loss_info['loss_func'] = loss_func - LOSS_MAPPING[loss_type] = loss_info - return - - def _register_loss_func(loss_func: Callable) -> Callable: - loss_info['loss_func'] = loss_func - LOSS_MAPPING[loss_type] = loss_info - return loss_func - - return _register_loss_func - - -def ce_loss_func(outputs, labels): +def per_token_loss_func(outputs, labels, **kwargs): logits = outputs.logits - device = logits.device - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:].to(device) - # Save memory - masks = shift_labels != -100 - shift_logits = shift_logits[masks] - shift_labels = shift_labels[masks] - # Flatten the tokens - loss_fct = CrossEntropyLoss(reduction='none') - loss = loss_fct(shift_logits, shift_labels) - return loss, masks - + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + labels = torch.roll(labels, shifts=-1, dims=-1) -# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train -@register_loss_func(LossType.loss_scale) -def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: - """Loss func - - Args: - outputs: The model outputs - labels: The labels - loss_scale: The loss scale - num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100. 
- - Returns: - - """ - loss, masks = ce_loss_func(outputs, labels) - if loss_scale is not None: - shift_scale = loss_scale[..., 1:].to(masks.device) - shift_scale = shift_scale[masks] - loss = (shift_scale * loss) - if num_items_in_batch is None: - loss = loss.mean() - else: - # compat transformers>=4.46 - loss = loss.sum() / num_items_in_batch + # Flatten the tokens + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none') return loss @@ -117,7 +57,6 @@ class SiameseDistanceMetric(Enum): COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa -@register_loss_func(LossType.cosine_similarity) def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: cos_score_transformation = nn.Identity() loss_fct = MSELoss() @@ -126,7 +65,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch= return loss_fct(output, labels.to(output.dtype).view(-1)) -@register_loss_func(LossType.contrastive) def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: sentence1, sentence2 = _parse_pair_sentence(outputs) distance_metric = SiameseDistanceMetric.COSINE_DISTANCE @@ -390,7 +328,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None): return split_tensors -@register_loss_func(LossType.infonce) def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01')) # temperature # calculate CE across the batch, meaning all samples will be negative except the matching positive @@ -491,7 +428,6 @@ def mask_fake_negative(sim_matrix, sim_labels): return loss -@register_loss_func(LossType.online_contrastive) def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: sentence1, sentence2 = _parse_pair_sentence(outputs) distance_metric = SiameseDistanceMetric.COSINE_DISTANCE @@ -510,13 +446,13 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch return loss -@register_loss_func(LossType.channel_loss) def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None, position_ids=None) -> torch.Tensor: + # Note: loss_scale is not supported at the moment. channels = trainer.args.channels assert channels is not None, 'Please pass --channels as a hyperparameter.' assert sample_channels is not None, 'Data does not have channel field.' @@ -583,7 +519,6 @@ def channel_loss_func(outputs, else total_loss / (total_tokens.float() + 1e-12) -@register_loss_func(LossType.reranker) def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: logits = outputs.logits logits = logits.squeeze(1) @@ -593,7 +528,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> return loss -@register_loss_func(LossType.generative_reranker) def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor: """ Generative reranker loss function. @@ -649,7 +583,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc return loss -@register_loss_func(LossType.listwise_reranker) def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor: """ List-wise reranker loss function. 
@@ -739,7 +672,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch= return total_loss / num_groups -@register_loss_func(LossType.listwise_generative_reranker) def listwise_generative_reranker_loss(outputs, labels, loss_scale=None, @@ -863,7 +795,23 @@ def listwise_generative_reranker_loss(outputs, return total_loss / num_groups +loss_mapping = { + 'per_token_cross_entropy': per_token_loss_func, + 'channel_loss': channel_loss_func, + # embedding + 'cosine_similarity': cosine_similarity_func, + 'contrastive': contrastive_loss, + 'online_contrastive': online_contrastive_loss, + 'infonce': infonce_loss, + # reranker + 'reranker': reranker_loss, + 'generative_reranker': generative_reranker_loss, + 'listwise_reranker': listwise_reranker_loss, + 'listwise_generative_reranker': listwise_generative_reranker_loss, +} + + def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]: if loss_type is None: return None - return LOSS_MAPPING[loss_type]['loss_func'] + return loss_mapping[loss_type] diff --git a/swift/plugin/loss_scale/loss_scale.py b/swift/plugin/loss_scale/loss_scale.py index 20c2ab8df7..d6904434b7 100644 --- a/swift/plugin/loss_scale/loss_scale.py +++ b/swift/plugin/loss_scale/loss_scale.py @@ -10,18 +10,14 @@ class LossScale: + # Indicates whether loss_scale contains only 0 and 1. + # If set to True, loss_scale will be replaced by labels to stay compatible with + # acceleration techniques such as liger_kernel. + # If set to False, an additional 'loss_scale' key will be stored and the + # corresponding loss function will be used. + is_binary = False loss_scale_config = None # path - def _set_keep_loss_scale(self): - self.keep_loss_scale = False - if self.loss_scale_map is None: - return - res = set() - for v in self.loss_scale_map.values(): - res.update(v) - if len(res - {0., 1.}) > 0: - self.keep_loss_scale = True - def __init__(self): if self.loss_scale_config is not None: path = os.path.dirname(os.path.abspath(__file__)) @@ -30,14 +26,9 @@ def __init__(self): self.loss_scale_map = json.load(json_file) else: self.loss_scale_map = None - self._set_keep_loss_scale() - def get_loss_scale(self, - context: str, - context_type: ContextType, - is_last_round: bool, - *, - query: Optional[str] = None) -> Tuple[List[str], List[float]]: + def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, + **kwargs) -> Tuple[List[str], List[float]]: """Calculate loss scale Args: @@ -80,7 +71,12 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me return res_context_list, res_loss_scale +class DefaultLossScale(LossScale): + is_binary = True + + class LastRoundLossScale(LossScale): + is_binary = True def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs): if context_type == ContextType.RESPONSE: @@ -129,6 +125,7 @@ class AlphaUmiLossScale(REACTLossScale): class TrainAllLossScale(LossScale): + is_binary = True def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs): return [context], [1.] 
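The `is_binary` flag introduced above is the main contract for custom loss-scale plugins: schemes that only ever emit 0/1 weights set it to True so the weights can be folded into `labels`, while real-valued schemes leave it False so the `loss_scale` tensor is kept and applied in the loss. A minimal sketch of a non-binary plugin, assuming it is added to `swift/plugin/loss_scale/loss_scale.py` and then registered in the `loss_scale_map` shown below; the class name and the 0.5 down-weighting rule are illustrative only:

```python
# Hypothetical addition to swift/plugin/loss_scale/loss_scale.py, where
# LossScale and ContextType are already available.
class DownweightHistoryLossScale(LossScale):
    # Emits a real-valued weight (0.5), so the loss_scale tensor must be kept
    # and applied in the loss function: leave is_binary as False.
    is_binary = False

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE:
            # Train on every response, but give non-final rounds half the weight.
            return [context], [1. if is_last_round else 0.5]
        # Prompts and other context receive no loss.
        return [context], [0.]
```

Registering it, e.g. `loss_scale_map['downweight_history'] = DownweightHistoryLossScale`, would make it selectable via `--loss_scale downweight_history`; because `is_binary` is False, the template keeps the `loss_scale` key and the weighted per-token path in `compute_loss` applies it.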
@@ -136,10 +133,12 @@ def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwarg class IgnoreEmptyThink(REACTLossScale): loss_scale_config = 'ignore_empty_think.json' + is_binary = True class LastRoundWithIgnoreEmptyThink(LossScale): loss_scale_config = 'ignore_empty_think.json' + is_binary = True def get_loss_scale(self, context: str, @@ -159,7 +158,7 @@ def get_loss_scale(self, # Add your loss scale here, use --loss_scale xxx to train loss_scale_map = { 'last_round': LastRoundLossScale, - 'default': LossScale, + 'default': DefaultLossScale, 'all': TrainAllLossScale, 'ignore_empty_think': IgnoreEmptyThink, 'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink, diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 579d2b5756..d92f8b8dc5 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -786,18 +786,26 @@ def _evalscope_eval(self): self.model.train() return eval_dict - def get_logits_to_keep(self, labels): + def prepare_logits_to_keep(self, inputs): + labels = inputs['labels'] + loss_scale = inputs.get('loss_scale') if labels.shape[0] == 1 and not is_mp(): # device_map may encounter device mismatch issues. loss_mask = (labels != -100)[0] labels = labels[:, loss_mask] labels = nn.functional.pad(labels, (1, 0), value=-100) + if loss_scale is not None: + loss_scale = loss_scale[:, loss_mask] + inputs['loss_scale'] = nn.functional.pad(loss_scale, (1, 0), value=0) logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True) else: logits_to_keep = labels.shape[-1] - ((labels != -100).int().argmax(-1).min().item()) + 1 assert logits_to_keep > 0 labels = labels[:, -logits_to_keep:] - return labels, logits_to_keep + if loss_scale is not None: + inputs['loss_scale'] = loss_scale[:, -logits_to_keep:] + inputs['labels'] = labels + inputs['logits_to_keep'] = logits_to_keep def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor: from swift.llm import get_packed_seq_params diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 02190d32f4..209fcc8938 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -68,15 +68,13 @@ def concatenated_forward( is_ref_model: bool = False ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch = batch.copy() - labels = batch.pop('labels', None) use_logits_to_keep = self.get_use_logits_to_keep(self.template.sequence_parallel_size == 1) if use_logits_to_keep: - labels, logits_to_keep = self.get_logits_to_keep(labels) - if logits_to_keep is not None: - batch['logits_to_keep'] = logits_to_keep + self.prepare_logits_to_keep(batch) if self.aux_loss_enabled: batch['output_router_logits'] = True + labels = batch.pop('labels', None) if self.is_encoder_decoder: batch['labels'] = labels position_ids = batch.pop('_position_ids', None) diff --git a/swift/trainers/rlhf_trainer/gkd_trainer.py b/swift/trainers/rlhf_trainer/gkd_trainer.py index e8ba103505..e844131ca4 100644 --- a/swift/trainers/rlhf_trainer/gkd_trainer.py +++ b/swift/trainers/rlhf_trainer/gkd_trainer.py @@ -88,9 +88,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # If generate is used, then use_logits_to_keep must be set to False. 
use_logits_to_keep = self.get_use_logits_to_keep(True) if use_logits_to_keep: - inputs['labels'], logits_to_keep = self.get_logits_to_keep(inputs['labels']) - if logits_to_keep is not None: - model_inputs['logits_to_keep'] = logits_to_keep + self.prepare_logits_to_keep(inputs) + model_inputs['logits_to_keep'] = inputs['logits_to_keep'] if self.args.sft_alpha > 0: model_inputs['labels'] = inputs['labels'] # compute student output diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index fa5b4da313..1aef292167 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -306,14 +306,8 @@ def _prepare_inputs(self, inputs): from swift.llm import HfConfigFactory args = self.args inputs = super()._prepare_inputs(inputs) - from swift.plugin.loss import get_loss_func loss_kwargs = {} compute_loss_func = self.compute_loss_func - loss_scale = inputs.pop('loss_scale', None) - if loss_scale is not None: - loss_kwargs['loss_scale'] = loss_scale - if compute_loss_func is None: - compute_loss_func = get_loss_func('loss_scale') sample_channels = inputs.pop('channel', None) position_ids = inputs.pop('_position_ids', None) @@ -330,14 +324,11 @@ def _prepare_inputs(self, inputs): if position_ids is not None: loss_kwargs['position_ids'] = position_ids - use_logits_to_keep = self.get_use_logits_to_keep('labels' in inputs and self.label_smoother is None - and compute_loss_func is None) + use_logits_to_keep = self.get_use_logits_to_keep(True) if use_logits_to_keep: - inputs['labels'], logits_to_keep = self.get_logits_to_keep(inputs['labels']) - if logits_to_keep is not None: - inputs['logits_to_keep'] = logits_to_keep - if args.tuner_backend == 'unsloth' and isinstance(logits_to_keep, torch.Tensor): - inputs['logits_to_keep'] = int(logits_to_keep.sum()) + self.prepare_logits_to_keep(inputs) + if args.tuner_backend == 'unsloth' and isinstance(inputs['logits_to_keep'], torch.Tensor): + inputs['logits_to_keep'] = int(inputs['logits_to_keep'].sum()) base_model = self.template.get_base_model(self.model) if self.model.model_info.is_moe_model and 'output_router_logits' in inspect.signature( @@ -352,11 +343,14 @@ def _prepare_inputs(self, inputs): return inputs def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + from swift.plugin import get_loss_func labels = None compute_loss_func = inputs.pop('compute_loss_func', None) + loss_scale = inputs.pop('loss_scale', None) loss_kwargs = inputs.pop('loss_kwargs', {}) - if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs: + if (self.label_smoother is not None or compute_loss_func is not None + or loss_scale is not None) and 'labels' in inputs: labels = inputs.pop('labels') outputs = model(**inputs) if getattr(outputs, 'aux_loss', None) is not None: @@ -382,6 +376,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # We don't use .loss here since the model may return tuples instead of ModelOutput. 
loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0] else: + outputs.loss = None + if loss_scale is not None: + loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1) + outputs.loss = get_loss_func('per_token_cross_entropy')(outputs, labels) + outputs.loss = outputs.loss * loss_scale unwrapped_model = self.accelerator.unwrap_model(model) if is_peft_available() and isinstance(unwrapped_model, PeftModel): model_name = unwrapped_model.model._get_name() @@ -390,10 +389,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # User-defined compute_loss function if compute_loss_func is not None: loss = compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs) - elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): - loss = self.label_smoother(outputs, labels, shift_labels=True) + elif self.label_smoother is not None: + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) else: - loss = self.label_smoother(outputs, labels) + if num_items_in_batch is None: + num_items_in_batch = (labels[:, 1:] != -100).sum() + loss = outputs.loss.sum() / num_items_in_batch if self.model.model_info.is_moe_model and self.args.router_aux_loss_coef is not None: aux_loss = outputs.get('aux_loss')
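To make the new `loss_scale` branch in `compute_loss` easier to follow, here is a minimal standalone sketch of the arithmetic with toy tensors (no trainer involved). It mirrors `per_token_loss_func` plus the rolled-scale weighting and the `num_items_in_batch` fallback shown above; the tensor values are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Toy batch: 1 sequence, 5 positions, vocabulary of 8 tokens.
logits = torch.randn(1, 5, 8)
labels = torch.tensor([[-100, 3, 5, 2, -100]])      # -100 marks unsupervised positions
loss_scale = torch.tensor([[0., 1., 2., 1., 0.]])   # per-token weights aligned with labels

# per_token_loss_func: roll labels left so logits at t are scored against the label at t + 1.
shifted_labels = torch.roll(labels, shifts=-1, dims=-1).view(-1)
per_token = F.cross_entropy(
    logits.float().view(-1, logits.shape[-1]), shifted_labels,
    ignore_index=-100, reduction='none')            # zeros where the label is -100

# compute_loss: roll the scale the same way, weight each token, then normalize.
shifted_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1)
weighted = per_token * shifted_scale
num_items_in_batch = (labels[:, 1:] != -100).sum()  # fallback used when the trainer passes None
loss = weighted.sum() / num_items_in_batch
```

With a purely binary scale this reduces to the unweighted cross-entropy path, which is why templates can drop `loss_scale` entirely when `is_binary` is True.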