from swift.plugin import MeanMetric


-class LossType:
-    loss_scale = 'loss_scale'
-    cosine_similarity = 'cosine_similarity'
-    contrastive = 'contrastive'
-    online_contrastive = 'online_contrastive'
-    infonce = 'infonce'
-    channel_loss = 'channel_loss'
-    reranker = 'reranker'
-    generative_reranker = 'generative_reranker'
-    listwise_reranker = 'listwise_reranker'
-    listwise_generative_reranker = 'listwise_generative_reranker'
-
-
-LOSS_MAPPING = {}
-
-
-def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
-    loss_info = {}
-
-    if loss_func is not None:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return
-
-    def _register_loss_func(loss_func: Callable) -> Callable:
-        loss_info['loss_func'] = loss_func
-        LOSS_MAPPING[loss_type] = loss_info
-        return loss_func
-
-    return _register_loss_func
-
-
-def ce_loss_func(outputs, labels):
+def per_token_loss_func(outputs, labels, **kwargs):
    logits = outputs.logits
-    device = logits.device
-    # Shift so that tokens < n predict n
-    shift_logits = logits[..., :-1, :]
-    shift_labels = labels[..., 1:].to(device)
-    # Save memory
-    masks = shift_labels != -100
-    shift_logits = shift_logits[masks]
-    shift_labels = shift_labels[masks]
-    # Flatten the tokens
-    loss_fct = CrossEntropyLoss(reduction='none')
-    loss = loss_fct(shift_logits, shift_labels)
-    return loss, masks
-
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    labels = torch.roll(labels, shifts=-1, dims=-1)

-# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
-@register_loss_func(LossType.loss_scale)
-def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
-    """Loss func
-
-    Args:
-        outputs: The model outputs
-        labels: The labels
-        loss_scale: The loss scale
-        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.
-
-    Returns:
-
-    """
-    loss, masks = ce_loss_func(outputs, labels)
-    if loss_scale is not None:
-        shift_scale = loss_scale[..., 1:].to(masks.device)
-        shift_scale = shift_scale[masks]
-        loss = (shift_scale * loss)
-    if num_items_in_batch is None:
-        loss = loss.mean()
-    else:
-        # compat transformers>=4.46
-        loss = loss.sum() / num_items_in_batch
+    # Flatten the tokens
+    logits = logits.view(-1, logits.shape[-1])
+    labels = labels.view(-1)
+    # Enable model parallelism
+    labels = labels.to(logits.device)
+    loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
    return loss


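Note on the new per_token_loss_func: it returns an unreduced per-token cross entropy (reduction='none'), so the caller is responsible for averaging or for dividing by num_items_in_batch. A minimal usage sketch follows; the module path, the dummy output object, and the final reduction are assumptions for illustration, not part of this diff.

import torch
from swift.plugin.loss import per_token_loss_func  # module path assumed

class DummyOutput:
    # Stand-in for a model output object exposing .logits
    def __init__(self, logits):
        self.logits = logits

batch, seq_len, vocab = 2, 8, 32
outputs = DummyOutput(torch.randn(batch, seq_len, vocab))
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :3] = -100  # masked prompt tokens are ignored by the loss

per_token_loss = per_token_loss_func(outputs, labels)  # shape: (batch * seq_len,)
# Ignored positions contribute 0, so dividing by the valid-token count gives the mean loss.
loss = per_token_loss.sum() / (labels != -100).sum()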
@@ -117,7 +57,6 @@ class SiameseDistanceMetric(Enum):
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)  # noqa


-@register_loss_func(LossType.cosine_similarity)
def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    cos_score_transformation = nn.Identity()
    loss_fct = MSELoss()
@@ -126,7 +65,6 @@ def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=
    return loss_fct(output, labels.to(output.dtype).view(-1))


-@register_loss_func(LossType.contrastive)
def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    sentence1, sentence2 = _parse_pair_sentence(outputs)
    distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -390,7 +328,6 @@ def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
    return split_tensors


-@register_loss_func(LossType.infonce)
def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01'))  # temperature
    # calculate CE across the batch, meaning all samples will be negative except the matching positive
@@ -491,7 +428,6 @@ def mask_fake_negative(sim_matrix, sim_labels):
    return loss


-@register_loss_func(LossType.online_contrastive)
def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    sentence1, sentence2 = _parse_pair_sentence(outputs)
    distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
@@ -510,13 +446,13 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
    return loss


-@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs,
                      labels,
                      num_items_in_batch=None,
                      sample_channels=None,
                      trainer=None,
                      position_ids=None) -> torch.Tensor:
+    # Note: loss_scale is not supported at the moment.
    channels = trainer.args.channels
    assert channels is not None, 'Please pass --channels as a hyperparameter.'
    assert sample_channels is not None, 'Data does not have channel field.'
@@ -583,7 +519,6 @@ def channel_loss_func(outputs,
        else total_loss / (total_tokens.float() + 1e-12)


-@register_loss_func(LossType.reranker)
def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    logits = outputs.logits
    logits = logits.squeeze(1)
@@ -593,7 +528,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) ->
    return loss


-@register_loss_func(LossType.generative_reranker)
def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, trainer=None) -> torch.Tensor:
    """
    Generative reranker loss function.
@@ -649,7 +583,6 @@ def generative_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batc
    return loss


-@register_loss_func(LossType.listwise_reranker)
def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """
    List-wise reranker loss function.
@@ -739,7 +672,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
    return total_loss / num_groups


-@register_loss_func(LossType.listwise_generative_reranker)
def listwise_generative_reranker_loss(outputs,
                                      labels,
                                      loss_scale=None,
@@ -863,7 +795,23 @@ def listwise_generative_reranker_loss(outputs,
    return total_loss / num_groups


+loss_mapping = {
+    'per_token_cross_entropy': per_token_loss_func,
+    'channel_loss': channel_loss_func,
+    # embedding
+    'cosine_similarity': cosine_similarity_func,
+    'contrastive': contrastive_loss,
+    'online_contrastive': online_contrastive_loss,
+    'infonce': infonce_loss,
+    # reranker
+    'reranker': reranker_loss,
+    'generative_reranker': generative_reranker_loss,
+    'listwise_reranker': listwise_reranker_loss,
+    'listwise_generative_reranker': listwise_generative_reranker_loss,
+}
+
+
def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
    if loss_type is None:
        return None
-    return LOSS_MAPPING[loss_type]['loss_func']
+    return loss_mapping[loss_type]
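With LossType and register_loss_func removed, loss selection now goes through the plain loss_mapping dict and get_loss_func. A custom loss can still be plugged in by adding a callable with the same (outputs, labels, **kwargs) contract to the mapping before training and passing its key via --loss_type. A minimal sketch, assuming the module path swift.plugin.loss and reusing the per-token pattern from this diff with label smoothing added; the key name is arbitrary.

import torch
import torch.nn.functional as F
from swift.plugin.loss import loss_mapping  # module path assumed

def smoothed_per_token_loss(outputs, labels, **kwargs) -> torch.Tensor:
    # Same contract as per_token_loss_func, with label smoothing on top.
    logits = outputs.logits.float()
    labels = torch.roll(labels, shifts=-1, dims=-1)
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.view(-1).to(logits.device)
    return F.cross_entropy(logits, labels, ignore_index=-100, reduction='none', label_smoothing=0.1)

# Selectable at training time with --loss_type smoothed_per_token
loss_mapping['smoothed_per_token'] = smoothed_per_token_loss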