
Commit 804acdd

fix ernie4_5 moe_layer to allgather (#2633)
1 parent 044e765 commit 804acdd


11 files changed: +189 -82 lines changed


examples/alignment/dpo/dpo_argument.py

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ class DPOConfig:
     ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict "})
     reference_free: bool = field(default=False, metadata={"help": "No reference model."})
     lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
+    offset_alpha: float = field(default=0.0, metadata={"help": "offset alpha"})
+    normalize_logps: bool = field(default=True, metadata={"help": "normalize logps"})


 @dataclass
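
The two new DPOConfig fields feed directly into the loss changes in dpo_loss.py below: offset_alpha weights an optional score-delta offset in the sigmoid DPO loss, and normalize_logps rescales chosen/rejected log-probabilities by the average response length. A minimal usage sketch with hypothetical values (in practice these fields are normally populated from the training arguments rather than constructed by hand):

# Hypothetical values; field names come from the diff above, other fields keep their defaults.
dpo_config = DPOConfig(
    offset_alpha=0.1,      # > 0 enables the score-delta offset term in the sigmoid DPO loss
    normalize_logps=True,  # rescale logps by average response length (see dpo_loss.py)
)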

paddleformers/nn/criterion/dpo_loss.py

Lines changed: 66 additions & 21 deletions
@@ -11,12 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddle.distributed.fleet.utils.sequence_parallel_utils import AllGatherOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp

 from ...transformers.model_outputs import CausalLMOutputWithPast
 from ...transformers.sequence_parallel_utils import (
@@ -57,10 +56,12 @@ def dpo_logps(
     bias = lm_head_bias
     transpose_y = self.tie_word_embeddings
     labels = chosen_labels + rejected_labels
+    ignore_index = kwargs.pop("ignore_index", 0)  # default is 0
+
     # drop ignored index token
     if self.use_filtered_label_loss:
         if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel and logits is None:
-            labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0)
+            labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, ignore_index)

             if hidden_states is not None:
                 hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
@@ -77,8 +78,15 @@ def dpo_logps(
             if logits is not None:
                 logits = paddle.gather(logits, sparse_tgt_idx, axis=1)
     else:
-        if hidden_states is not None:
-            hidden_states = AllGatherOp.apply(hidden_states)
+        if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel and hidden_states is not None:
+            hidden_states = GatherOp.apply(hidden_states)
+            hidden_states = hidden_states.reshape(
+                [
+                    -1,
+                    self.config.max_sequence_length,
+                    hidden_states.shape[-1],
+                ]
+            )

     # bsz,seq_len,hidden_size or seq_len,hidden_size
     seq_len = labels.shape[1] if labels.ndim == 2 else labels.shape[0]
@@ -97,7 +105,7 @@ def dpo_logps(
             False,  # fused_linear
             self.loss_subbatch_sequence_length,
             return_token_loss=True,
-            ignore_index=0,
+            ignore_index=ignore_index,
         )
         per_token_logps = per_token_logps.reshape([1, per_token_logps.shape[-1], 1])
     else:
@@ -109,7 +117,6 @@ def dpo_logps(
             transpose_y=transpose_y,
             tensor_parallel_output=self.config.tensor_parallel_output,
         )
-
         if isinstance(logits, tuple):
             logits = logits[0]
         elif isinstance(logits, CausalLMOutputWithPast):
@@ -129,14 +136,15 @@ def dpo_logps(
                 1,
             )

-            per_token_logps = sb_loss_func(logits, labels.unsqueeze(-1))
+            per_token_logps = -sb_loss_func(logits, labels.unsqueeze(-1))
         else:
-            per_token_logps = self.loss_func(logits, labels.unsqueeze(-1))
+            per_token_logps = -self.loss_func(logits, labels.unsqueeze(-1))

     if len(response_indexs.shape) == 3:
         response_indexs = response_indexs[0]

     offset = 1 if self.ignore_eos_token else 0
+
     if self.use_filtered_label_loss:
         chosen_logps = paddle.stack(
             [
@@ -146,6 +154,8 @@ def dpo_logps(
                         paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -159,6 +169,8 @@ def dpo_logps(
                         paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -173,6 +185,8 @@ def dpo_logps(
                         paddle.arange(response_index[1], response_index[2], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -186,6 +200,8 @@ def dpo_logps(
                         paddle.arange(response_index[2] + offset, response_index[3], dtype=paddle.int32),
                         axis=0,
                     ).sum()
+                    if response_index[3] != 0
+                    else paddle.to_tensor(100.0)
                 )
                 for response_index in response_indexs
             ],
@@ -194,22 +210,36 @@ def dpo_logps(

     sft_loss = -chosen_logps.sum() / (chosen_labels != 0).sum()
     if average_log_prob:
-        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1] - offset
+        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1]
         rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2]
         chosen_logps /= chosen_response_length.astype("float32")
         rejected_logps /= rejected_response_length.astype("float32")
+    elif self.dpo_config.normalize_logps:
+        avg_response_length = (response_indexs[:, 3] - response_indexs[:, 1]) / 2
+        chosen_response_length = response_indexs[:, 2] - response_indexs[:, 1]
+        rejected_response_length = response_indexs[:, 3] - response_indexs[:, 2]
+        chosen_logps *= avg_response_length / chosen_response_length.astype("float32")
+        rejected_logps *= avg_response_length / rejected_response_length.astype("float32")
     return chosen_logps, rejected_logps, sft_loss * self.dpo_config.sft_loss_ratio


 def cal_dpo_loss(
-    self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **kwargs
+    self,
+    policy_chosen_logps,
+    policy_rejected_logps,
+    reference_chosen_logps,
+    reference_rejected_logps,
+    score_deltas,
+    **kwargs
 ):
     """DPO Loss"""
     pi_logratios = policy_chosen_logps - policy_rejected_logps
     ref_logratios = reference_chosen_logps - reference_rejected_logps
     logits = pi_logratios - ref_logratios

     if self.dpo_config.loss_type == "sigmoid":
+        if self.dpo_config.offset_alpha > 0 and score_deltas is not None:
+            logits = logits - self.dpo_config.offset_alpha / self.dpo_config.beta * paddle.log(score_deltas + 1e-6)
         loss = (
             -F.log_sigmoid(self.dpo_config.beta * logits) * (1 - self.dpo_config.label_smoothing)
             - F.log_sigmoid(-self.dpo_config.beta * logits) * self.dpo_config.label_smoothing
@@ -282,21 +312,31 @@ def cal_dpo_loss(


 def dpo_loss_forward(
-    self: nn.Layer, logits: paddle.Tensor, labels: paddle.Tensor, loss_mask: Optional[paddle.Tensor] = None, **kwargs
+    self: nn.Layer, logits: paddle.Tensor, labels: paddle.Tensor, loss_mask: paddle.Tensor = None, **kwargs
 ):
     # unpack logtis and labels
     logits, labels, hidden_states, lm_head_weight, lm_head_bias, transpose_y = dpo_preprocess_inputs(
         self, logits, labels
     )

-    (
-        chosen_labels,
-        rejected_labels,
-        response_indexs,
-        score_deltas,
-        reference_chosen_logps,
-        reference_rejected_logps,
-    ) = labels
+    if self.dpo_config.offset_alpha > 0 or len(labels) == 6:
+        (
+            chosen_labels,
+            rejected_labels,
+            response_indexs,
+            score_deltas,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        ) = labels
+    else:
+        (
+            chosen_labels,
+            rejected_labels,
+            response_indexs,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        ) = labels
+        score_deltas = None

     average_log_prob = False
     if self.dpo_config.loss_type in ["ipo", "or", "simpo"]:
@@ -336,7 +376,12 @@ def dpo_loss_forward(
         **kwargs,
     )
     dpo_loss = cal_dpo_loss(
-        self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+        self,
+        policy_chosen_logps,
+        policy_rejected_logps,
+        reference_chosen_logps,
+        reference_rejected_logps,
+        score_deltas,
     )

     loss = dpo_loss + sft_loss
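
For reference, the offset variant of the sigmoid DPO loss introduced above can be written as a standalone function. This is a minimal sketch that mirrors the diff; the function name and default values are illustrative, not the repository's API:

import paddle
import paddle.nn.functional as F

def sigmoid_dpo_loss_with_offset(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps,
    score_deltas=None, beta=0.1, offset_alpha=0.0, label_smoothing=0.0,
):
    # Implicit reward margin between chosen and rejected responses.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    # Optional offset: a larger preference-score gap demands a larger margin.
    if offset_alpha > 0 and score_deltas is not None:
        logits = logits - offset_alpha / beta * paddle.log(score_deltas + 1e-6)
    loss = (
        -F.log_sigmoid(beta * logits) * (1 - label_smoothing)
        - F.log_sigmoid(-beta * logits) * label_smoothing
    )
    return loss.mean()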

paddleformers/nn/criterion/sft_loss.py

Lines changed: 77 additions & 43 deletions
@@ -15,6 +15,7 @@

 import paddle
 import paddle.nn as nn
+from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import AllGatherOp

 from ...transformers.sequence_parallel_utils import (
@@ -51,41 +52,14 @@ def sft_postprocess_loss(self, masked_lm_loss, labels, loss_mask, **kwargs):
     return loss, loss_sum


-def sft_loss_forward(
-    self: nn.Layer,
-    logits: Union[paddle.Tensor, Tuple[paddle.Tensor]],
-    labels: Union[paddle.Tensor, Tuple[paddle.Tensor]],
-    loss_mask: paddle.Tensor = None,
-    **kwargs
-):
-    logits, labels, hidden_states, lm_head_weight, lm_head_bias, transpose_y = sft_preprocess_inputs(
-        self, logits, labels
-    )
-    if self.use_filtered_label_loss:
-        if self.tensor_parallel and self.sequence_parallel and logits is None:
-            masked_lm_labels, sparse_label_idx = sequence_parallel_sparse_mask_labels(labels, self.ignored_index)
-            sparse_label_idx = sparse_label_idx.reshape([-1, 1])
-            if hidden_states is not None:
-                hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0)
-                hidden_states = AllGatherVarlenOp.apply(hidden_states)
-        else:
-            masked_lm_labels = labels.flatten()
-            sparse_label_idx = paddle.nonzero(masked_lm_labels != self.ignored_index).flatten()
-            masked_lm_labels = paddle.take_along_axis(masked_lm_labels, sparse_label_idx, axis=0)
-            if hidden_states is not None:
-                hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
-                hidden_states = paddle.take_along_axis(hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0)
-            if logits is not None:
-                logits = paddle.gather(logits, sparse_label_idx, axis=1)
-        labels = masked_lm_labels
-    else:
-        if self.sequence_parallel:
-            if hidden_states is not None:
-                hidden_states = AllGatherOp.apply(hidden_states)
+def loss_impl(self, logits, labels):
+    logits = logits.cast("float32")
+    loss = self.loss_func(logits, labels)
+    return loss

-        masked_lm_labels = labels
-    # bsz,seq_len,hidden_size or seq_len,hidden_size
-    seq_len = masked_lm_labels.shape[1] if masked_lm_labels.ndim == 2 else masked_lm_labels.shape[0]
+
+def sft_calculate_loss(self, logits, hidden_states, lm_head_weight, lm_head_bias, labels, loss_mask, transpose_y):
+    seq_len = labels.shape[1] if labels.ndim == 2 else labels.shape[0]
     if self.use_fused_head_and_loss_fn and self.use_subbatch and seq_len > self.loss_subbatch_sequence_length:
         masked_lm_loss = fused_head_and_loss_fn(
             hidden_states,
@@ -123,7 +97,6 @@ def sft_loss_forward(
                 f" {logits.shape[-1]}, {self.config.vocab_size}"
             )

-        logits = logits.cast("float32")
         if logits.dim() == 2 and labels.dim() == 2:
             logits = logits.unsqueeze(0)
         elif logits.dim() == 3 and labels.dim() == 1:
@@ -133,16 +106,77 @@
         # labels: bsz seq_len vocab_size
         if self.use_subbatch and seq_len > self.loss_subbatch_sequence_length:
             sb_loss_func = subbatch(
-                self.loss_func,
-                [0, 1],
-                [1, 1],
-                self.loss_subbatch_sequence_length,
-                1,
+                loss_impl,
+                arg_idx=[1, 2],
+                axis=[1, 1],
+                bs=self.loss_subbatch_sequence_length,
+                out_idx=1,
             )
-            masked_lm_loss = sb_loss_func(logits, labels.unsqueeze(-1))
+            masked_lm_loss = sb_loss_func(self, logits, labels.unsqueeze(-1))
         else:
-            masked_lm_loss = self.loss_func(logits, labels.unsqueeze(-1))
-    loss = sft_postprocess_loss(self, masked_lm_loss, labels, loss_mask, **kwargs)
+            masked_lm_loss = loss_impl(self, logits, labels.unsqueeze(-1))
+
+    masked_lm_loss = sft_postprocess_loss(self, masked_lm_loss, labels, loss_mask)
+    return masked_lm_loss
+
+
+def sft_loss_forward(
+    self: nn.Layer,
+    logits: Union[paddle.Tensor, Tuple[paddle.Tensor]],
+    labels: Union[paddle.Tensor, Tuple[paddle.Tensor]],
+    loss_mask: paddle.Tensor = None,
+    **kwargs
+):
+    logits, labels, hidden_states, lm_head_weight, lm_head_bias, transpose_y = sft_preprocess_inputs(
+        self, logits, labels
+    )
+    if self.use_filtered_label_loss:
+        if self.tensor_parallel and self.sequence_parallel and logits is None:
+            masked_lm_labels, sparse_label_idx = sequence_parallel_sparse_mask_labels(labels, self.ignored_index)
+            sparse_label_idx = sparse_label_idx.reshape([-1, 1])
+            if hidden_states is not None:
+                hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0)
+                hidden_states = AllGatherVarlenOp.apply(hidden_states)
+        else:
+            masked_lm_labels = labels.flatten()
+            sparse_label_idx = paddle.nonzero(masked_lm_labels != self.ignored_index).flatten()
+            masked_lm_labels = paddle.take_along_axis(masked_lm_labels, sparse_label_idx, axis=0)
+            if hidden_states is not None:
+                hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
+                hidden_states = paddle.take_along_axis(hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0)
+            if logits is not None:
+                logits = paddle.gather(logits, sparse_label_idx, axis=1)
+        labels = masked_lm_labels
+    else:
+        if self.sequence_parallel:
+            if hidden_states is not None:
+                hidden_states = AllGatherOp.apply(hidden_states)
+
+        masked_lm_labels = labels
+    # bsz,seq_len,hidden_size or seq_len,hidden_size
+    if self.config.recompute:
+        loss = recompute(
+            sft_calculate_loss,
+            self,
+            logits,
+            hidden_states,
+            lm_head_weight,
+            lm_head_bias,
+            labels,
+            loss_mask,
+            transpose_y,
+        )
+    else:
+        loss = sft_calculate_loss(
+            self,
+            logits,
+            hidden_states,
+            lm_head_weight,
+            lm_head_bias,
+            labels,
+            loss_mask,
+            transpose_y,
+        )
     return loss
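
The refactor above exists mainly so the head/loss computation can be wrapped in recompute, which frees intermediate activations after the forward pass and rebuilds them during backward, trading compute for memory. A self-contained sketch of that pattern with toy shapes (not the repository's code):

import paddle
from paddle.distributed.fleet.utils import recompute

def lm_head_and_loss(hidden, weight, labels):
    # Activations created inside this function are recomputed in backward.
    logits = paddle.matmul(hidden, weight).cast("float32")
    return paddle.nn.functional.cross_entropy(logits, labels)

hidden = paddle.randn([8, 64])
hidden.stop_gradient = False
weight = paddle.randn([64, 1000])
weight.stop_gradient = False
labels = paddle.randint(0, 1000, [8])

loss = recompute(lm_head_and_loss, hidden, weight, labels)  # memory-saving path
loss.backward()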

paddleformers/nn/mlp.py

Lines changed: 6 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import paddle
 import paddle.nn as nn
 from paddle.incubate.nn.functional import swiglu as fused_swiglu

@@ -117,5 +117,9 @@ def forward(self, x):
         else:
             gate = self.gate_proj(x)
             up = self.up_proj(x)
-            x = self.act_fn(gate) * up
+            if self.fuse_swiglu:
+                x = paddle.concat([gate, up], axis=-1)
+                x = fused_swiglu(x)
+            else:
+                x = self.act_fn(gate) * up
         return self.down_proj(x)
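
For intuition, the fused branch should be numerically equivalent to the unfused one: applying fused_swiglu to the concatenated [gate, up] tensor computes SiLU(gate) * up in a single kernel. A small check, assuming a device where paddle.incubate.nn.functional.swiglu is available and that it splits its single input in half along the last axis:

import paddle
import paddle.nn.functional as F
from paddle.incubate.nn.functional import swiglu as fused_swiglu

gate = paddle.randn([2, 8, 128])
up = paddle.randn([2, 8, 128])

fused = fused_swiglu(paddle.concat([gate, up], axis=-1))  # fused kernel path
unfused = F.silu(gate) * up                               # reference path
print(paddle.allclose(fused, unfused, atol=1e-5))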

paddleformers/nn/moe/all_gather.py

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ def forward(

         recv_mask_alltoall_out = paddle.cat(recv_mask_alltoall_out, 0)
         distributed_input_to_alltoall_out = paddle.maximum(
-            recv_mask_alltoall_out.cumsum() - 1,
+            (recv_mask_alltoall_out.cumsum() - 1).astype(recv_mask_alltoall_out.dtype),
             paddle.zeros([1], dtype=recv_mask_alltoall_out.dtype),
         )
         distributed_input_to_alltoall_out = distributed_input_to_alltoall_out.split(alltoall_shape)
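
The added cast guards against a dtype mismatch in paddle.maximum: cumsum can promote the mask's dtype (bool or small-integer masks may come back as int64), while the zeros tensor is created with the original dtype. A minimal illustration with hypothetical values:

import paddle

mask = paddle.to_tensor([0, 1, 1, 0], dtype="int32")  # stand-in for recv_mask_alltoall_out
pos = mask.cumsum() - 1                                # cumsum may promote the dtype
pos = pos.astype(mask.dtype)                           # cast back so both maximum() operands match
clamped = paddle.maximum(pos, paddle.zeros([1], dtype=mask.dtype))
print(clamped.numpy())  # [0 0 1 1]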
