
Commit 13790da

Unified FuseLoss (#11057)
1 parent bad7890 commit 13790da

3 files changed, +20 −27 lines changed

paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 16 additions & 0 deletions
@@ -30,6 +30,12 @@
 from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp
+except:
+    pass
+
+
 from ...transformers.llama.modeling import (
     LlamaPretrainingCriterion as PretrainingCriterion,
 )

@@ -446,6 +452,16 @@ def forward(

         else:
             hidden_states, weight, bias, transpose_y = logits
+            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+                hidden_states = GatherOp.apply(hidden_states)
+                hidden_states = hidden_states.reshape(
+                    [
+                        input_ids.shape[0],
+                        -1,
+                        hidden_states.shape[-1],
+                    ]
+                )
+
         if use_remove_padding:
             input_ids = raw_input_ids
             if pad_size > 0:
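
For context, a minimal standalone sketch of what the added branch does when tensor parallelism and sequence parallelism are both on: the sequence-parallel shard is gathered back to the full sequence and then reshaped to [batch, seq, hidden] before the loss. The sketch skips the real GatherOp (it needs an initialized distributed group) and only illustrates the reshape; shapes and names are illustrative, not taken from the diff.

import paddle

# Illustrative shapes only; in the real code path GatherOp.apply would first
# concatenate the sequence-parallel shards along the sequence axis.
batch_size, seq_len, hidden = 2, 8, 16
input_ids = paddle.zeros([batch_size, seq_len], dtype="int64")

# Sequence-parallel activations arrive flattened as [batch * seq, hidden].
hidden_states = paddle.randn([batch_size * seq_len, hidden])

# The added branch restores the [batch, seq, hidden] layout the loss expects.
hidden_states = hidden_states.reshape(
    [input_ids.shape[0], -1, hidden_states.shape[-1]]
)
print(hidden_states.shape)  # [2, 8, 16]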

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 4 additions & 13 deletions
@@ -1480,6 +1480,10 @@ def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=Fals
         self.weight.split_axis = 0 if self.transpose_y else 1

     def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
+        # add this for fused_head_and_loss_fn
+        if self.config.use_fused_head_and_loss_fn:
+            return hidden_states, self.weight, None, self.transpose_y
+
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
             hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])

@@ -1667,19 +1671,6 @@ def forward(

         hidden_states = outputs[0]

-        # add this for fused_head_and_loss_fn
-        if self.config.use_fused_head_and_loss_fn and self.training:
-            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
-                hidden_states = GatherOp.apply(hidden_states)
-                hidden_states = hidden_states.reshape(
-                    [
-                        batch_size,
-                        -1,
-                        hidden_states.shape[-1],
-                    ]
-                )
-            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y
-
         # if labels is None,means we need full output, instead of tensor_parallel_output
         # tensor_parallel_output is together with ParallelCrossEntropy
         tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
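
For orientation, the tuple that the lm_head now returns feeds the unified criterion in ppo_model_utils.py. The sketch below is a hypothetical, non-fused reference of what a criterion computes from (hidden_states, weight, bias, transpose_y); the function name and the tiling remark are illustrative and not part of this commit.

import paddle
import paddle.nn.functional as F

def naive_head_and_loss(hidden_states, weight, bias, transpose_y, labels):
    # Reference (non-fused) version: project to vocab logits with the lm_head
    # weight, then take cross-entropy. A real fused implementation tiles this
    # so the full [batch, seq, vocab] logits tensor is never materialized.
    logits = paddle.matmul(hidden_states, weight, transpose_y=transpose_y)
    if bias is not None:
        logits = logits + bias
    return F.cross_entropy(
        logits.reshape([-1, logits.shape[-1]]), labels.reshape([-1])
    )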

paddlenlp/transformers/qwen3/modeling.py

Lines changed: 0 additions & 14 deletions
@@ -70,7 +70,6 @@

 try:
     from paddle.distributed.fleet.utils.sequence_parallel_utils import (
-        GatherOp,
         ScatterOp,
         mark_as_sequence_parallel_parameter,
     )

@@ -1139,19 +1138,6 @@ def forward(

         hidden_states = outputs[0]

-        # add this for fused_head_and_loss_fn
-        if self.config.use_fused_head_and_loss_fn and self.training:
-            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
-                hidden_states = GatherOp.apply(hidden_states)
-                hidden_states = hidden_states.reshape(
-                    [
-                        batch_size,
-                        -1,
-                        hidden_states.shape[-1],
-                    ]
-                )
-            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y
-
         # if labels is None,means we need full output, instead of tensor_parallel_output
         # tensor_parallel_output is together with ParallelCrossEntropy
         tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
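
Since qwen3 no longer imports GatherOp directly (the gather now lives in the shared criterion), a conceptual stand-in for what that op does may help. This is a hedged sketch assuming an initialized paddle.distributed process group; it is not PaddleNLP's actual implementation, whose op is additionally autograd-aware.

import paddle
import paddle.distributed as dist

def gather_sequence_shards(local_shard):
    # Conceptual stand-in for GatherOp.apply: all-gather the per-rank sequence
    # shards and concatenate them along the first (sequence) axis so every rank
    # sees the full sequence. The real op's backward scatters gradients back to
    # the local shard.
    shards = []
    dist.all_gather(shards, local_shard)
    return paddle.concat(shards, axis=0)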
