@@ -11,12 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
 
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddle.distributed.fleet.utils.sequence_parallel_utils import AllGatherOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp
 
 from ...transformers.model_outputs import CausalLMOutputWithPast
 from ...transformers.sequence_parallel_utils import (
@@ -57,10 +56,12 @@ def dpo_logps(
     bias = lm_head_bias
     transpose_y = self.tie_word_embeddings
     labels = chosen_labels + rejected_labels
+    ignore_index = kwargs.pop("ignore_index", 0)  # default is 0
+
     # drop ignored index token
     if self.use_filtered_label_loss:
         if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel and logits is None:
-            labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0)
+            labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, ignore_index)
 
         if hidden_states is not None:
             hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
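
The new kwarg makes the ignored label id configurable rather than hard-coded to 0, with the default unchanged. A toy sketch of the plumbing (standalone names, not the real `dpo_logps` signature), assuming the chosen/rejected label tensors are disjoint so their elementwise sum merges them, as `labels = chosen_labels + rejected_labels` in the hunk implies:

```python
import paddle

def masked_label_count(chosen_labels, rejected_labels, **kwargs):
    ignore_index = kwargs.pop("ignore_index", 0)  # default stays 0, as before
    labels = chosen_labels + rejected_labels      # disjoint masks merge cleanly
    return (labels != ignore_index).sum()         # tokens that contribute to the loss

chosen = paddle.to_tensor([[0, 11, 12, 0]])
rejected = paddle.to_tensor([[0, 0, 0, 7]])
print(masked_label_count(chosen, rejected))  # 3 supervised tokens
```
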
@@ -77,8 +78,15 @@ def dpo_logps(
         if logits is not None:
             logits = paddle.gather(logits, sparse_tgt_idx, axis=1)
     else:
-        if hidden_states is not None:
-            hidden_states = AllGatherOp.apply(hidden_states)
+        if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel and hidden_states is not None:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = hidden_states.reshape(
+                [
+                    -1,
+                    self.config.max_sequence_length,
+                    hidden_states.shape[-1],
+                ]
+            )
 
     # bsz,seq_len,hidden_size or seq_len,hidden_size
     seq_len = labels.shape[1] if labels.ndim == 2 else labels.shape[0]
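
In the non-filtered path, the rewritten branch only re-assembles activations when sequence parallelism is actually on: `GatherOp` concatenates each rank's sequence shard along axis 0, and the reshape folds the flat `[bsz * seq_len, hidden]` result back to `[bsz, max_sequence_length, hidden]`. A single-process sketch of the shape arithmetic, with the distributed gather mocked by a plain concat:

```python
import paddle

# Mock of GatherOp.apply for one process: under sequence parallelism each
# rank holds seq_len / tp_degree rows; the gather concatenates the shards.
def mock_gather(shards):
    return paddle.concat(shards, axis=0)

tp_degree, bsz, max_seq, hidden = 2, 2, 8, 4
shards = [paddle.randn([bsz * max_seq // tp_degree, hidden]) for _ in range(tp_degree)]

hidden_states = mock_gather(shards)  # [bsz * max_seq, hidden]
# Same reshape as the hunk: fold the flat layout back into batched sequences.
hidden_states = hidden_states.reshape([-1, max_seq, hidden_states.shape[-1]])
assert hidden_states.shape == [bsz, max_seq, hidden]
```
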
@@ -97,7 +105,7 @@ def dpo_logps(
                 False,  # fused_linear
                 self.loss_subbatch_sequence_length,
                 return_token_loss=True,
-                ignore_index=0,
+                ignore_index=ignore_index,
             )
             per_token_logps = per_token_logps.reshape([1, per_token_logps.shape[-1], 1])
         else:
@@ -109,7 +117,6 @@ def dpo_logps(
                 transpose_y=transpose_y,
                 tensor_parallel_output=self.config.tensor_parallel_output,
             )
-
             if isinstance(logits, tuple):
                 logits = logits[0]
             elif isinstance(logits, CausalLMOutputWithPast):
@@ -129,14 +136,15 @@ def dpo_logps(
                 1,
             )
 
-            per_token_logps = sb_loss_func(logits, labels.unsqueeze(-1))
+            per_token_logps = -sb_loss_func(logits, labels.unsqueeze(-1))
         else:
-            per_token_logps = self.loss_func(logits, labels.unsqueeze(-1))
+            per_token_logps = -self.loss_func(logits, labels.unsqueeze(-1))
 
     if len(response_indexs.shape) == 3:
         response_indexs = response_indexs[0]
 
     offset = 1 if self.ignore_eos_token else 0
+
     if self.use_filtered_label_loss:
         chosen_logps = paddle.stack(
             [
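
The two sign flips convert per-token cross-entropy losses into per-token log-probabilities: cross-entropy of a label is `-log p(label)`, so negating the loss recovers `log p(label)`. A quick standalone check:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 10])       # [tokens, vocab]
labels = paddle.randint(0, 10, [4])

# Per-token cross-entropy is -log p(label); negating it yields the
# per-token log-probability, which is what the sign flips above compute.
per_token_logps = -F.cross_entropy(logits, labels, reduction="none")

# The same quantity taken directly from the log-softmax.
picked = paddle.take_along_axis(F.log_softmax(logits, axis=-1), labels.unsqueeze(-1), axis=-1)
assert bool(paddle.allclose(per_token_logps.flatten(), picked.flatten(), atol=1e-5))
```
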
@@ -146,6 +154,8 @@ def dpo_logps(
                         paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -159,6 +169,8 @@ def dpo_logps(
                         paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -173,6 +185,8 @@ def dpo_logps(
                         paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -186,6 +200,8 @@ def dpo_logps(
                         paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
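
All four stacked comprehensions receive the same guard: when a response slot is empty (`response_index[3] == 0`) there is no token range to select, so a constant placeholder tensor is stacked instead of reducing an empty selection. A toy illustration of the pattern (indices and values are made up; `int()` casts added for clarity):

```python
import paddle

# Toy per-token log-probs and response index rows:
# (batch_idx, chosen_start, chosen_end, rejected_end); an all-zero row marks
# an unused slot, which is what the new `response_index[3] != 0` guard skips.
per_token_logps = paddle.log(paddle.uniform([12], min=0.1, max=1.0))
response_indexs = paddle.to_tensor([[0, 1, 5, 9], [0, 0, 0, 0]])

chosen_logps = paddle.stack(
    [
        (
            paddle.index_select(
                per_token_logps,
                paddle.arange(int(idx[1]), int(idx[2]), dtype="int32"),
                axis=0,
            ).sum()
            if int(idx[3]) != 0
            else paddle.to_tensor(100.0)  # placeholder for the empty slot
        )
        for idx in response_indexs
    ],
    axis=0,
)
print(chosen_logps.shape)  # one (possibly placeholder) value per row
```
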
@@ -194,22 +210,36 @@ def dpo_logps(
 
     sft_loss = -chosen_logps.sum() / (chosen_labels != 0).sum()
     if average_log_prob:
-        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] - offset
+        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1]
         rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2]
         chosen_logps /= chosen_response_length.astype("float32")
         rejected_logps /= rejected_response_length.astype("float32")
+    elif self.dpo_config.normalize_logps:
+        avg_response_length = (response_indexs[:, 3] - response_indexs[:, 1]) / 2
+        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1]
+        rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2]
+        chosen_logps *= avg_response_length / chosen_response_length.astype("float32")
+        rejected_logps *= avg_response_length / rejected_response_length.astype("float32")
     return chosen_logps, rejected_logps, sft_loss * self.dpo_config.sft_loss_ratio
 
 
 def cal_dpo_loss(
-    self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **kwargs
+    self,
+    policy_chosen_logps,
+    policy_rejected_logps,
+    reference_chosen_logps,
+    reference_rejected_logps,
+    score_deltas,
+    **kwargs
 ):
     """DPO Loss"""
     pi_logratios = policy_chosen_logps - policy_rejected_logps
     ref_logratios = reference_chosen_logps - reference_rejected_logps
     logits = pi_logratios - ref_logratios
 
     if self.dpo_config.loss_type == "sigmoid":
+        if self.dpo_config.offset_alpha > 0 and score_deltas is not None:
+            logits = logits - self.dpo_config.offset_alpha / self.dpo_config.beta * paddle.log(score_deltas + 1e-6)
         loss = (
             -F.log_sigmoid(self.dpo_config.beta * logits) * (1 - self.dpo_config.label_smoothing)
             - F.log_sigmoid(-self.dpo_config.beta * logits) * self.dpo_config.label_smoothing
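
Two behavioral additions land here: `normalize_logps` rescales each side's summed log-probability by `avg_len / own_len` so a length gap between chosen and rejected does not dominate the margin, and `offset_alpha` subtracts a score-gap term from the DPO logits before the sigmoid loss. A worked numeric sketch of both (toy values; the reference log-ratio is omitted for brevity):

```python
import paddle

# Toy response_indexs row: (batch_idx, chosen_start, chosen_end, rejected_end).
response_indexs = paddle.to_tensor([[0, 1, 5, 11]], dtype="float32")
chosen_logps = paddle.to_tensor([-8.0])
rejected_logps = paddle.to_tensor([-15.0])

# normalize_logps branch: rescale each side by avg_len / own_len.
avg_len = (response_indexs[:, 3] - response_indexs[:, 1]) / 2   # 5.0
chosen_len = response_indexs[:, 2] - response_indexs[:, 1]      # 4.0
rejected_len = response_indexs[:, 3] - response_indexs[:, 2]    # 6.0
chosen_logps = chosen_logps * avg_len / chosen_len              # -10.0
rejected_logps = rejected_logps * avg_len / rejected_len        # -12.5

# offset_alpha branch: shrink the preference margin by a term that grows
# with the annotated score gap, as in the hunk above.
beta, offset_alpha = 0.1, 0.05
score_deltas = paddle.to_tensor([2.0])
margin = chosen_logps - rejected_logps                                   # 2.5
margin = margin - offset_alpha / beta * paddle.log(score_deltas + 1e-6)  # ~2.15
```
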
@@ -282,21 +312,31 @@ def cal_dpo_loss(
 
 
 def dpo_loss_forward(
-    self: nn.Layer, logits: paddle.Tensor, labels: paddle.Tensor, loss_mask: Optional[paddle.Tensor] = None, **kwargs
+    self: nn.Layer, logits: paddle.Tensor, labels: paddle.Tensor, loss_mask: paddle.Tensor = None, **kwargs
 ):
     # unpack logits and labels
     logits, labels, hidden_states, lm_head_weight, lm_head_bias, transpose_y = dpo_preprocess_inputs(
         self, logits, labels
     )
 
-    (
-        chosen_labels,
-        rejected_labels,
-        response_indexs,
-        score_deltas,
-        reference_chosen_logps,
-        reference_rejected_logps,
-    ) = labels
+    if self.dpo_config.offset_alpha > 0 or len(labels) == 6:
+        (
+            chosen_labels,
+            rejected_labels,
+            response_indexs,
+            score_deltas,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        ) = labels
+    else:
+        (
+            chosen_labels,
+            rejected_labels,
+            response_indexs,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        ) = labels
+        score_deltas = None
 
     average_log_prob = False
     if self.dpo_config.loss_type in ["ipo", "or", "simpo"]:
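
The unpacking now tolerates both label layouts: the 6-tuple that carries per-pair `score_deltas` (required when `offset_alpha > 0`) and the legacy 5-tuple, where `score_deltas` falls back to `None`. A minimal standalone mirror of the branching:

```python
def unpack_dpo_labels(labels, offset_alpha=0.0):
    """Standalone mirror of the branching above (not the library API)."""
    if offset_alpha > 0 or len(labels) == 6:
        chosen, rejected, response_indexs, score_deltas, ref_chosen, ref_rejected = labels
    else:
        chosen, rejected, response_indexs, ref_chosen, ref_rejected = labels
        score_deltas = None
    return chosen, rejected, response_indexs, score_deltas, ref_chosen, ref_rejected

# Legacy 5-tuple and score-carrying 6-tuple normalize to the same shape:
assert unpack_dpo_labels((1, 2, 3, 4, 5))[3] is None
assert unpack_dpo_labels((1, 2, 3, 9, 4, 5))[3] == 9
```
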
@@ -336,7 +376,12 @@ def dpo_loss_forward(
         **kwargs,
     )
     dpo_loss = cal_dpo_loss(
-        self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+        self,
+        policy_chosen_logps,
+        policy_rejected_logps,
+        reference_chosen_logps,
+        reference_rejected_logps,
+        score_deltas,
     )
 
     loss = dpo_loss + sft_loss