[plugin] refactor loss_type/loss_scale #5337

Merged
merged 10 commits on Aug 11, 2025
5 changes: 3 additions & 2 deletions docs/source/Customization/插件化.md
@@ -32,10 +32,11 @@ An example is [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin
SWIFT supports customizing the loss in a plugin. If this capability is not used, the default cross-entropy loss (CE Loss) applies. Developers can write code in this file, and once registered, the trainer will automatically use your customized loss method.
For example, add the following code in plugin/loss.py:
```python
-@register_loss_func("custom_loss")
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    # Write your own loss calculation here
    return loss

+loss_mapping['custom_loss'] = custom_loss_func
```
Note that the loss is strongly tied to the task the trainer runs. The current loss customization targets the pt and sft tasks; for human-alignment tasks (e.g., DPO, PPO) or classification (seq_cls) tasks, the loss cannot be customized through the plugin.

5 changes: 3 additions & 2 deletions docs/source_en/Customization/Pluginization.md
@@ -37,10 +37,11 @@ SWIFT supports customizing the loss function through plugins. If this feature is
For example, add the following code in `plugin/loss.py`:

```python
-@register_loss_func("custom_loss")
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    # Write your own loss calculation here
    return loss

+loss_mapping['custom_loss'] = custom_loss_func
```

It is important to note that the loss function is strongly related to the training task. Currently, loss customization supports PT and SFT tasks. For human alignment tasks (e.g., DPO, PPO) or classification tasks (seq_cls), loss customization through plugins is not supported.
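For reference, here is a fuller, self-contained sketch of what such an addition to `plugin/loss.py` could look like after this refactor. The shifting, masking, and reduction details below are illustrative choices, not something this PR prescribes:

```python
import torch
import torch.nn.functional as F


def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    # Shift labels so that position i predicts token i + 1, then compute per-token CE.
    logits = outputs.logits.float()
    labels = torch.roll(labels, shifts=-1, dims=-1).to(logits.device)
    labels[..., -1] = -100  # drop the label that wrapped around to the last position
    loss = F.cross_entropy(
        logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100, reduction='none')
    if num_items_in_batch is None:
        return loss[labels.view(-1) != -100].mean()
    # Follows the gradient-accumulation convention of transformers>=4.46.
    return loss.sum() / num_items_in_batch


# loss_mapping is defined in plugin/loss.py, so no import is needed there;
# registering the function makes it selectable on the CLI.
loss_mapping['custom_loss'] = custom_loss_func
```

A training run would then pick it up with `swift sft ... --loss_type custom_loss`.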
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -19,7 +19,7 @@ numpy
openai
oss2
pandas
-peft>=0.11,<0.17
+peft>=0.11,<0.18
pillow
PyYAML>=5.4
requests
4 changes: 2 additions & 2 deletions swift/llm/argument/train_args.py
@@ -6,7 +6,7 @@
from transformers import Seq2SeqTrainingArguments
from transformers.utils.versions import require_version

-from swift.plugin import LOSS_MAPPING
+from swift.plugin import loss_mapping
from swift.trainers import TrainerFactory
from swift.trainers.arguments import TrainArgumentsMixin
from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master,
@@ -118,7 +118,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
    create_checkpoint_symlink: bool = False

    # plugin
-    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
+    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'})
    metric: Optional[str] = None

    # extra
8 changes: 2 additions & 6 deletions swift/llm/template/base.py
@@ -971,10 +971,6 @@ def _encode_context_list(
        tokenizer_kwargs = {}
        if loss_scale_list is None:
            loss_scale_list = [0.] * len(context_list)
-        if self.loss_scale.keep_loss_scale:
-            ignore_loss_scale = False
-        else:
-            ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
        for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str):
                # tokenizer_kwargs is the returned tokenizer_kwargs,
@@ -987,9 +983,9 @@
                labels += token_list
            else:
                labels += [-100] * len(token_list)
-            if not ignore_loss_scale:
+            if not self.loss_scale.is_binary:
                loss_scale.extend([loss_weight] * len(token_list))
-        if ignore_loss_scale:
+        if self.loss_scale.is_binary:
            loss_scale = None
        return input_ids, labels, loss_scale, tokenizer_kwargs

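The rationale for replacing the runtime `keep_loss_scale` check with a static `is_binary` class flag: when every weight is 0 or 1, per-token weighting is exactly equivalent to `-100` label masking, so the template can drop the `loss_scale` tensor entirely and stay on the fast path (e.g., liger_kernel). A quick illustrative check of that equivalence:

```python
import torch
import torch.nn.functional as F

# Illustrative tensors: 4 tokens, vocab of 10, binary loss-scale weights.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 5, 2, 7])
scale = torch.tensor([0., 1., 1., 0.])

# Weighted per-token loss, averaged over the tokens that actually count.
per_token = F.cross_entropy(logits, labels, reduction='none')
weighted = (per_token * scale).sum() / scale.sum()

# The same quantity expressed as -100 masking, with no loss_scale tensor at all.
masked_labels = torch.where(scale > 0, labels, torch.tensor(-100))
masked = F.cross_entropy(logits, masked_labels)  # mean over non-ignored tokens

assert torch.allclose(weighted, masked)
```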
2 changes: 2 additions & 0 deletions swift/megatron/utils/convert.py
@@ -182,6 +182,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
    megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, template)
+    del hf_model
    logger.info('Successfully transferred HF model weights to MG model.')
    args.save_args()
    mg_save_checkpoint(1, [mg_model], None, None, 0)
@@ -228,6 +229,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, template)
+    del mg_model
    logger.info('Successfully transferred MG model weights to HF model.')
    ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
    save_checkpoint(
4 changes: 2 additions & 2 deletions swift/plugin/__init__.py
@@ -5,7 +5,7 @@

if TYPE_CHECKING:
    from .callback import extra_callbacks
-    from .loss import LOSS_MAPPING, get_loss_func
+    from .loss import loss_mapping, get_loss_func
    from .loss_scale import loss_scale_map
    from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
    from .optimizer import optimizers_map
@@ -21,7 +21,7 @@
else:
    _import_structure = {
        'callback': ['extra_callbacks'],
-        'loss': ['LOSS_MAPPING', 'get_loss_func'],
+        'loss': ['loss_mapping', 'get_loss_func'],
        'loss_scale': ['loss_scale_map'],
        'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
        'optimizer': ['optimizers_map'],
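For readers unfamiliar with the `TYPE_CHECKING`/`_import_structure` split: it is the standard lazy-import pattern, where names resolve for type checkers at import time but submodules only load on first attribute access. A minimal sketch of the mechanism, assuming it lives in a package's `__init__.py` (ms-swift delegates this to a lazy-module helper rather than spelling it out):

```python
import importlib

_import_structure = {'loss': ['loss_mapping', 'get_loss_func']}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}


def __getattr__(name):  # PEP 562: module-level __getattr__
    if name in _attr_to_module:
        module = importlib.import_module('.' + _attr_to_module[name], __package__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
```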
108 changes: 28 additions & 80 deletions swift/plugin/loss.py
@@ -15,78 +15,18 @@
from swift.plugin import MeanMetric


-class LossType:
-    loss_scale = 'loss_scale'
-    cosine_similarity = 'cosine_similarity'
-    contrastive = 'contrastive'
-    online_contrastive = 'online_contrastive'
-    infonce = 'infonce'
-    channel_loss = 'channel_loss'
-    reranker = 'reranker'
-    generative_reranker = 'generative_reranker'
-    listwise_reranker = 'listwise_reranker'
-    listwise_generative_reranker = 'listwise_generative_reranker'


-LOSS_MAPPING = {}


-def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
-    loss_info = {}

-    if loss_func is not None:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return

-    def _register_loss_func(loss_func: Callable) -> Callable:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return loss_func

-    return _register_loss_func


-def ce_loss_func(outputs, labels):
+def per_token_loss_func(outputs, labels, **kwargs):
    logits = outputs.logits
-    device = logits.device
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :]
-    shift_labels = labels[..., 1:].to(device)
-    # Save memory
-    masks = shift_labels != -100
-    shift_logits = shift_logits[masks]
-    shift_labels = shift_labels[masks]
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss(reduction='none')
-    loss = loss_fct(shift_logits, shift_labels)
-    return loss, masks

+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    labels = torch.roll(labels, shifts=-1, dims=-1)

-# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
-@register_loss_func(LossType.loss_scale)
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
-    """Loss func

-    Args:
-        outputs: The model outputs
-        labels: The labels
-        loss_scale: The loss scale
-        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.

-    Returns:

-    """
-    loss, masks = ce_loss_func(outputs, labels)
-    if loss_scale is not None:
-        shift_scale = loss_scale[..., 1:].to(masks.device)
-        shift_scale = shift_scale[masks]
-        loss = (shift_scale * loss)
-    if num_items_in_batch is None:
-        loss = loss.mean()
-    else:
-        # compat transformers>=4.46
-        loss = loss.sum() / num_items_in_batch
+    # Flatten the tokens
+    logits = logits.view(-1, logits.shape[-1])
+    labels = labels.view(-1)
+    # Enable model parallelism
+    labels = labels.to(logits.device)
+    loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
    return loss


@@ -117,7 +57,6 @@ class SiameseDistanceMetric(Enum):
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)  # noqa


-@register_loss_func(LossType.cosine_similarity)
def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    cos_score_transformation = nn.Identity()
    loss_fct = MSELoss()
@@ -126,7 +65,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=
    return loss_fct(output, labels.to(output.dtype).view(-1))


-@register_loss_func(LossType.contrastive)
def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    sentence1, sentence2 = _parse_pair_sentence(outputs)
    distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -390,7 +328,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
    return split_tensors


-@register_loss_func(LossType.infonce)
def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01'))  # temperature
    # calculate CE across the batch, meaning all samples will be negative except the matching positive
@@ -491,7 +428,6 @@ def mask_fake_negative(sim_matrix, sim_labels):
    return loss


-@register_loss_func(LossType.online_contrastive)
def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    sentence1, sentence2 = _parse_pair_sentence(outputs)
    distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -510,13 +446,13 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
    return loss


-@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs,
                      labels,
                      num_items_in_batch=None,
                      sample_channels=None,
                      trainer=None,
                      position_ids=None) -> torch.Tensor:
+    # Note: loss_scale is not supported at the moment.
    channels = trainer.args.channels
    assert channels is not None, 'Please pass --channels as a hyperparameter.'
    assert sample_channels is not None, 'Data does not have channel field.'
@@ -583,7 +519,6 @@ def channel_loss_func(outputs,
        else total_loss / (total_tokens.float() + 1e-12)


-@register_loss_func(LossType.reranker)
def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    logits = outputs.logits
    logits = logits.squeeze(1)
@@ -593,7 +528,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) ->
    return loss


-@register_loss_func(LossType.generative_reranker)
def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor:
    """
    Generative reranker loss function.
@@ -649,7 +583,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc
    return loss


-@register_loss_func(LossType.listwise_reranker)
def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """
    List-wise reranker loss function.
@@ -739,7 +672,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
    return total_loss / num_groups


-@register_loss_func(LossType.listwise_generative_reranker)
def listwise_generative_reranker_loss(outputs,
                                      labels,
                                      loss_scale=None,
@@ -863,7 +795,23 @@ def listwise_generative_reranker_loss(outputs,
    return total_loss / num_groups


+loss_mapping = {
+    'per_token_cross_entropy': per_token_loss_func,
+    'channel_loss': channel_loss_func,
+    # embedding
+    'cosine_similarity': cosine_similarity_func,
+    'contrastive': contrastive_loss,
+    'online_contrastive': online_contrastive_loss,
+    'infonce': infonce_loss,
+    # reranker
+    'reranker': reranker_loss,
+    'generative_reranker': generative_reranker_loss,
+    'listwise_reranker': listwise_reranker_loss,
+    'listwise_generative_reranker': listwise_generative_reranker_loss,
+}


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
    if loss_type is None:
        return None
-    return LOSS_MAPPING[loss_type]['loss_func']
+    return loss_mapping[loss_type]
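With the decorator registry gone, resolving a loss is a plain dict lookup. A hedged sketch of calling the new per-token entry directly, where the `outputs` object is only a stand-in for a real model output:

```python
import torch
from types import SimpleNamespace

from swift.plugin import get_loss_func

loss_func = get_loss_func('per_token_cross_entropy')

# Stand-in batch: 1 sequence, 5 positions, vocab size 32.
outputs = SimpleNamespace(logits=torch.randn(1, 5, 32))
labels = torch.tensor([[-100, 4, 9, 2, -100]])

per_token = loss_func(outputs, labels)           # shape (5,); zeros at ignored positions
loss = per_token.sum() / (labels != -100).sum()  # reduce however the trainer decides
```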
35 changes: 17 additions & 18 deletions swift/plugin/loss_scale/loss_scale.py
@@ -10,18 +10,14 @@


class LossScale:
+    # Indicates whether loss_scale contains only 0 and 1.
+    # If set to True, loss_scale will be replaced by labels to stay compatible with
+    # acceleration techniques such as liger_kernel.
+    # If set to False, an additional 'loss_scale' key will be stored and the
+    # corresponding loss function will be used.
+    is_binary = False
    loss_scale_config = None  # path

-    def _set_keep_loss_scale(self):
-        self.keep_loss_scale = False
-        if self.loss_scale_map is None:
-            return
-        res = set()
-        for v in self.loss_scale_map.values():
-            res.update(v)
-        if len(res - {0., 1.}) > 0:
-            self.keep_loss_scale = True

    def __init__(self):
        if self.loss_scale_config is not None:
            path = os.path.dirname(os.path.abspath(__file__))
@@ -30,14 +26,9 @@ def __init__(self):
                self.loss_scale_map = json.load(json_file)
        else:
            self.loss_scale_map = None
-        self._set_keep_loss_scale()

-    def get_loss_scale(self,
-                       context: str,
-                       context_type: ContextType,
-                       is_last_round: bool,
-                       *,
-                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
+    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool,
+                       **kwargs) -> Tuple[List[str], List[float]]:
        """Calculate loss scale

        Args:
@@ -80,7 +71,12 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me
        return res_context_list, res_loss_scale


+class DefaultLossScale(LossScale):
+    is_binary = True


class LastRoundLossScale(LossScale):
+    is_binary = True

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE:
@@ -129,17 +125,20 @@ class AlphaUmiLossScale(REACTLossScale):


class TrainAllLossScale(LossScale):
+    is_binary = True

    def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
        return [context], [1.]


class IgnoreEmptyThink(REACTLossScale):
    loss_scale_config = 'ignore_empty_think.json'
+    is_binary = True


class LastRoundWithIgnoreEmptyThink(LossScale):
    loss_scale_config = 'ignore_empty_think.json'
+    is_binary = True

    def get_loss_scale(self,
                       context: str,
@@ -159,7 +158,7 @@
# Add your loss scale here, use --loss_scale xxx to train
loss_scale_map = {
    'last_round': LastRoundLossScale,
-    'default': LossScale,
+    'default': DefaultLossScale,
    'all': TrainAllLossScale,
    'ignore_empty_think': IgnoreEmptyThink,
    'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
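Adding a new scheme still works the way the comment above describes: subclass `LossScale`, decide whether its weights are binary, and register it in `loss_scale_map`. A hypothetical subclass (not part of this PR) with fractional weights, which must therefore leave `is_binary` at its default `False` so the `loss_scale` tensor is kept:

```python
# Hypothetical addition to swift/plugin/loss_scale/loss_scale.py.
class ProgressiveLossScale(LossScale):
    # Fractional weights cannot be expressed as -100 masking,
    # so is_binary stays False (the LossScale default).

    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
        if context_type == ContextType.RESPONSE:
            # Weight the final round fully, earlier rounds at half strength.
            return [context], [1. if is_last_round else 0.5]
        return [context], [0.]


loss_scale_map['progressive'] = ProgressiveLossScale
```

A run would select it with `--loss_scale progressive`.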