
Commit 0071e07

[plugin] refactor loss_type/loss_scale (#5337)
1 parent d8bdeeb commit 0071e07

13 files changed: +95, -139 lines


docs/source/Customization/插件化.md

Lines changed: 3 additions & 2 deletions
@@ -32,10 +32,11 @@ The example is [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin
 SWIFT supports customizing the loss in a plugin. If this capability is not used, the cross-entropy loss (CE Loss) is used by default. Developers can write code in this file; once it is registered, the trainer will automatically use the customized loss function.
 For example, add the following code in plugin/loss.py:
 ```python
-@register_loss_func("custom_loss")
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     # Write your own loss calculation here
     return loss
+
+loss_mapping['custom_loss'] = custom_loss_func
 ```
 Note that the loss is strongly tied to the task the trainer runs. The current loss customization targets pt and sft tasks; for human-alignment tasks (such as DPO and PPO) or classification (seq_cls) tasks, the loss cannot be customized in a plugin.

docs/source_en/Customization/Pluginization.md

Lines changed: 3 additions & 2 deletions
@@ -37,10 +37,11 @@ SWIFT supports customizing the loss function through plugins. If this feature is
 For example, adding the following code in `plugin/loss.py`:
 
 ```python
-@register_loss_func("custom_loss")
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     # Write your own loss calculation here
     return loss
+
+loss_mapping['custom_loss'] = custom_loss_func
 ```
 
 It is important to note that the loss function is strongly related to the training task. Currently, loss customization supports PT and SFT tasks. For human alignment tasks (e.g., DPO, PPO) or classification tasks (seq_cls), loss customization through plugins is not supported.
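To make the new registration flow concrete, here is a minimal, self-contained sketch of such a plugin function. Only the `loss_mapping['custom_loss'] = custom_loss_func` registration line and the function signature follow the documented interface above; the per-token weighting logic itself is an illustrative assumption, not code from this commit.

```python
import torch
import torch.nn.functional as F

# Exported by swift.plugin after this commit; when editing swift/plugin/loss.py
# directly, loss_mapping is already defined in that module.
from swift.plugin import loss_mapping


def custom_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    # Shift so that tokens < n predict token n, then compute a per-token CE loss.
    logits = outputs.logits[..., :-1, :].float()
    shift_labels = labels[..., 1:].to(logits.device)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction='none')
    if loss_scale is not None:
        # Optional per-token weights supplied by the template (assumed to share labels' shape).
        loss = loss * loss_scale[..., 1:].reshape(-1).to(loss.device)
    if num_items_in_batch is None:
        return loss[shift_labels.reshape(-1) != -100].mean()
    return loss.sum() / num_items_in_batch


loss_mapping['custom_loss'] = custom_loss_func
```

Training would then be launched with `--loss_type custom_loss`.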

requirements/framework.txt

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ numpy
 openai
 oss2
 pandas
-peft>=0.11,<0.17
+peft>=0.11,<0.18
 pillow
 PyYAML>=5.4
 requests

swift/llm/argument/train_args.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from transformers import Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 
-from swift.plugin import LOSS_MAPPING
+from swift.plugin import loss_mapping
 from swift.trainers import TrainerFactory
 from swift.trainers.arguments import TrainArgumentsMixin
 from swift.utils import (add_version_to_work_dir, get_device_count, get_logger, get_pai_tensorboard_dir, is_master,
@@ -118,7 +118,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
     create_checkpoint_symlink: bool = False
 
     # plugin
-    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
+    loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(loss_mapping.keys())}'})
     metric: Optional[str] = None
 
     # extra

swift/llm/template/base.py

Lines changed: 2 additions & 6 deletions
@@ -971,10 +971,6 @@ def _encode_context_list(
         tokenizer_kwargs = {}
         if loss_scale_list is None:
             loss_scale_list = [0.] * len(context_list)
-        if self.loss_scale.keep_loss_scale:
-            ignore_loss_scale = False
-        else:
-            ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
         for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
             if isinstance(context, str):
                 # tokenizer_kwargs is the returned tokenizer_kwargs,
@@ -987,9 +983,9 @@ def _encode_context_list(
                     labels += token_list
                 else:
                     labels += [-100] * len(token_list)
-                if not ignore_loss_scale:
+                if not self.loss_scale.is_binary:
                     loss_scale.extend([loss_weight] * len(token_list))
-        if ignore_loss_scale:
+        if self.loss_scale.is_binary:
             loss_scale = None
         return input_ids, labels, loss_scale, tokenizer_kwargs
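A small, purely illustrative walk-through of what the `is_binary` short-circuit means for encoding (values below are made up, not taken from the repository): when the selected loss-scale class declares itself binary, the per-token weights are already captured by the `labels` mask, so no separate `loss_scale` list is returned.

```python
# Hypothetical example of the is_binary short-circuit.
loss_scale_list = [0., 1., 1.]                 # one weight per context chunk
token_lists = [[101, 102], [5, 6, 7], [8]]     # tokens produced for each chunk

is_binary = True                               # e.g. what DefaultLossScale / LastRoundLossScale declare

labels, loss_scale = [], []
for weight, tokens in zip(loss_scale_list, token_lists):
    labels += tokens if weight > 0 else [-100] * len(tokens)
    if not is_binary:
        loss_scale.extend([weight] * len(tokens))
if is_binary:
    loss_scale = None   # labels alone carry the mask; kernels such as liger_kernel stay usable

print(labels)      # [-100, -100, 5, 6, 7, 8]
print(loss_scale)  # None
```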

swift/megatron/utils/convert.py

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
     if args.test_convert_precision:
         test_convert_precision(hf_model, mg_model, template)
+    del hf_model
     logger.info('Successfully transferred HF model weights to MG model.')
     args.save_args()
     mg_save_checkpoint(1, [mg_model], None, None, 0)
@@ -228,6 +229,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
     if args.test_convert_precision:
         test_convert_precision(hf_model, mg_model, template)
+    del mg_model
     logger.info('Successfully transferred MG model weights to HF model.')
     ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
     save_checkpoint(

swift/plugin/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 if TYPE_CHECKING:
     from .callback import extra_callbacks
-    from .loss import LOSS_MAPPING, get_loss_func
+    from .loss import loss_mapping, get_loss_func
     from .loss_scale import loss_scale_map
     from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
     from .optimizer import optimizers_map
@@ -21,7 +21,7 @@
 else:
     _import_structure = {
         'callback': ['extra_callbacks'],
-        'loss': ['LOSS_MAPPING', 'get_loss_func'],
+        'loss': ['loss_mapping', 'get_loss_func'],
         'loss_scale': ['loss_scale_map'],
         'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
         'optimizer': ['optimizers_map'],

swift/plugin/loss.py

Lines changed: 28 additions & 80 deletions
@@ -15,78 +15,18 @@
 from swift.plugin import MeanMetric
 
 
-class LossType:
-    loss_scale = 'loss_scale'
-    cosine_similarity = 'cosine_similarity'
-    contrastive = 'contrastive'
-    online_contrastive = 'online_contrastive'
-    infonce = 'infonce'
-    channel_loss = 'channel_loss'
-    reranker = 'reranker'
-    generative_reranker = 'generative_reranker'
-    listwise_reranker = 'listwise_reranker'
-    listwise_generative_reranker = 'listwise_generative_reranker'
-
-
-LOSS_MAPPING = {}
-
-
-def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
-    loss_info = {}
-
-    if loss_func is not None:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return
-
-    def _register_loss_func(loss_func: Callable) -> Callable:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return loss_func
-
-    return _register_loss_func
-
-
-def ce_loss_func(outputs, labels):
+def per_token_loss_func(outputs, labels, **kwargs):
     logits = outputs.logits
-    device = logits.device
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :]
-    shift_labels = labels[..., 1:].to(device)
-    # Save memory
-    masks = shift_labels != -100
-    shift_logits = shift_logits[masks]
-    shift_labels = shift_labels[masks]
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss(reduction='none')
-    loss = loss_fct(shift_logits, shift_labels)
-    return loss, masks
-
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    labels = torch.roll(labels, shifts=-1, dims=-1)
 
-# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
-@register_loss_func(LossType.loss_scale)
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
-    """Loss func
-
-    Args:
-        outputs: The model outputs
-        labels: The labels
-        loss_scale: The loss scale
-        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.
-
-    Returns:
-
-    """
-    loss, masks = ce_loss_func(outputs, labels)
-    if loss_scale is not None:
-        shift_scale = loss_scale[..., 1:].to(masks.device)
-        shift_scale = shift_scale[masks]
-        loss = (shift_scale * loss)
-    if num_items_in_batch is None:
-        loss = loss.mean()
-    else:
-        # compat transformers>=4.46
-        loss = loss.sum() / num_items_in_batch
+    # Flatten the tokens
+    logits = logits.view(-1, logits.shape[-1])
+    labels = labels.view(-1)
+    # Enable model parallelism
+    labels = labels.to(logits.device)
+    loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
     return loss
 
 
@@ -117,7 +57,6 @@ class SiameseDistanceMetric(Enum):
     COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)  # noqa
 
 
-@register_loss_func(LossType.cosine_similarity)
 def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     cos_score_transformation = nn.Identity()
     loss_fct = MSELoss()
@@ -126,7 +65,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=
     return loss_fct(output, labels.to(output.dtype).view(-1))
 
 
-@register_loss_func(LossType.contrastive)
 def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     sentence1, sentence2 = _parse_pair_sentence(outputs)
     distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -390,7 +328,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
     return split_tensors
 
 
-@register_loss_func(LossType.infonce)
 def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01'))  # temperature
     # calculate CE across the batch, meaning all samples will be negative except the matching positive
@@ -491,7 +428,6 @@ def mask_fake_negative(sim_matrix, sim_labels):
     return loss
 
 
-@register_loss_func(LossType.online_contrastive)
 def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     sentence1, sentence2 = _parse_pair_sentence(outputs)
     distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -510,13 +446,13 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
     return loss
 
 
-@register_loss_func(LossType.channel_loss)
 def channel_loss_func(outputs,
                       labels,
                       num_items_in_batch=None,
                       sample_channels=None,
                       trainer=None,
                       position_ids=None) -> torch.Tensor:
+    # Note: loss_scale is not supported at the moment.
     channels = trainer.args.channels
     assert channels is not None, 'Please pass --channels as a hyperparameter.'
     assert sample_channels is not None, 'Data does not have channel field.'
@@ -583,7 +519,6 @@ def channel_loss_func(outputs,
         else total_loss / (total_tokens.float() + 1e-12)
 
 
-@register_loss_func(LossType.reranker)
 def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     logits = outputs.logits
     logits = logits.squeeze(1)
@@ -593,7 +528,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) ->
     return loss
 
 
-@register_loss_func(LossType.generative_reranker)
 def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor:
     """
     Generative reranker loss function.
@@ -649,7 +583,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc
     return loss
 
 
-@register_loss_func(LossType.listwise_reranker)
 def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
     """
     List-wise reranker loss function.
@@ -739,7 +672,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
     return total_loss / num_groups
 
 
-@register_loss_func(LossType.listwise_generative_reranker)
 def listwise_generative_reranker_loss(outputs,
                                       labels,
                                       loss_scale=None,
@@ -863,7 +795,23 @@ def listwise_generative_reranker_loss(outputs,
     return total_loss / num_groups
 
 
+loss_mapping = {
+    'per_token_cross_entropy': per_token_loss_func,
+    'channel_loss': channel_loss_func,
+    # embedding
+    'cosine_similarity': cosine_similarity_func,
+    'contrastive': contrastive_loss,
+    'online_contrastive': online_contrastive_loss,
+    'infonce': infonce_loss,
+    # reranker
+    'reranker': reranker_loss,
+    'generative_reranker': generative_reranker_loss,
+    'listwise_reranker': listwise_reranker_loss,
+    'listwise_generative_reranker': listwise_generative_reranker_loss,
+}
+
+
 def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
     if loss_type is None:
         return None
-    return LOSS_MAPPING[loss_type]['loss_func']
+    return loss_mapping[loss_type]
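After the refactor, a loss is selected simply by key in `loss_mapping`. A brief sketch of how the mapping is consumed (illustrative usage, not code from the commit):

```python
from swift.plugin import get_loss_func, loss_mapping

# All built-in choices for --loss_type after this commit.
print(list(loss_mapping.keys()))
# ['per_token_cross_entropy', 'channel_loss', 'cosine_similarity', 'contrastive',
#  'online_contrastive', 'infonce', 'reranker', 'generative_reranker',
#  'listwise_reranker', 'listwise_generative_reranker']

loss_func = get_loss_func('per_token_cross_entropy')  # the callable itself, no ['loss_func'] lookup
assert loss_func is loss_mapping['per_token_cross_entropy']
assert get_loss_func(None) is None  # no --loss_type: the trainer keeps its default CE loss
```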

swift/plugin/loss_scale/loss_scale.py

Lines changed: 17 additions & 18 deletions
@@ -10,18 +10,14 @@
 
 
 class LossScale:
+    # Indicates whether loss_scale contains only 0 and 1.
+    # If set to True, loss_scale will be replaced by labels to stay compatible with
+    # acceleration techniques such as liger_kernel.
+    # If set to False, an additional 'loss_scale' key will be stored and the
+    # corresponding loss function will be used.
+    is_binary = False
     loss_scale_config = None  # path
 
-    def _set_keep_loss_scale(self):
-        self.keep_loss_scale = False
-        if self.loss_scale_map is None:
-            return
-        res = set()
-        for v in self.loss_scale_map.values():
-            res.update(v)
-        if len(res - {0., 1.}) > 0:
-            self.keep_loss_scale = True
-
     def __init__(self):
         if self.loss_scale_config is not None:
             path = os.path.dirname(os.path.abspath(__file__))
@@ -30,14 +26,9 @@ def __init__(self):
                 self.loss_scale_map = json.load(json_file)
         else:
             self.loss_scale_map = None
-        self._set_keep_loss_scale()
 
-    def get_loss_scale(self,
-                       context: str,
-                       context_type: ContextType,
-                       is_last_round: bool,
-                       *,
-                       query: Optional[str] = None) -> Tuple[List[str], List[float]]:
+    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool,
+                       **kwargs) -> Tuple[List[str], List[float]]:
         """Calculate loss scale
 
         Args:
@@ -80,7 +71,12 @@ def __call__(self, context_list: List[str], context_types: List[ContextType], me
         return res_context_list, res_loss_scale
 
 
+class DefaultLossScale(LossScale):
+    is_binary = True
+
+
 class LastRoundLossScale(LossScale):
+    is_binary = True
 
     def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
         if context_type == ContextType.RESPONSE:
@@ -129,17 +125,20 @@ class AlphaUmiLossScale(REACTLossScale):
 
 
 class TrainAllLossScale(LossScale):
+    is_binary = True
 
     def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
         return [context], [1.]
 
 
 class IgnoreEmptyThink(REACTLossScale):
     loss_scale_config = 'ignore_empty_think.json'
+    is_binary = True
 
 
 class LastRoundWithIgnoreEmptyThink(LossScale):
     loss_scale_config = 'ignore_empty_think.json'
+    is_binary = True
 
     def get_loss_scale(self,
                        context: str,
@@ -159,7 +158,7 @@ def get_loss_scale(self,
 # Add your loss scale here, use --loss_scale xxx to train
 loss_scale_map = {
     'last_round': LastRoundLossScale,
-    'default': LossScale,
+    'default': DefaultLossScale,
     'all': TrainAllLossScale,
     'ignore_empty_think': IgnoreEmptyThink,
     'last_round_with_ignore_empty_think': LastRoundWithIgnoreEmptyThink,
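As a final illustration, a hedged sketch of how a non-binary loss scale could be added under the new scheme. The class, its weights, and the 'emphasize_response' key are hypothetical; only the `is_binary` attribute and the `loss_scale_map` registration follow the interface shown above. The sketch assumes it lives in this same loss_scale plugin module, where `LossScale`, `ContextType`, and `loss_scale_map` are already in scope.

```python
class EmphasizeResponseLossScale(LossScale):
    # Weights other than 0/1 are produced, so is_binary stays False and templates
    # will keep a separate 'loss_scale' key alongside labels.
    is_binary = False

    def get_loss_scale(self, context, context_type, is_last_round, **kwargs):
        if context_type == ContextType.RESPONSE and is_last_round:
            return [context], [2.]  # hypothetical: up-weight the final response
        return super().get_loss_scale(context, context_type, is_last_round, **kwargs)


loss_scale_map['emphasize_response'] = EmphasizeResponseLossScale
# Training could then select it with --loss_scale emphasize_response.
```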
