
Commit c861b78

cheng221 and llbdyiu66 authored
Add ernie4 5 moe (#2520)
Co-authored-by: llbdyiu66 <[email protected]>
1 parent c239ca8 commit c861b78

30 files changed: 6,800 additions, 454 deletions

paddleformers/generation/utils.py

Lines changed: 77 additions & 50 deletions
@@ -20,13 +20,11 @@
 
 import paddle
 import paddle.distributed as dist
-import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
-from paddle.common_ops_import import convert_dtype
 from paddle.utils import map_structure
 
-from ..transformers.model_outputs import ModelOutput
+from ..transformers.model_outputs import CausalLMOutputWithPast, ModelOutput
 from ..transformers.utils import get_scale_by_dtype
 from ..utils.log import logger
 from ..utils.masking_utils import _expand_2d_mask, _make_causal_mask
@@ -493,61 +491,38 @@ def expand_inputs_for_generation(input_ids, expand_size, attention_mask=None, **
 
     @staticmethod
     def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
-        # Update the model inputs during generation.
-        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
-        # and they contain pad value, the result vectors updated by this method
-        # may be different from expected. In this case, you need to rewrite the
-        # method.
+        """
+        Updates model kwargs for generation.
+
+        Args:
+            outputs (Any): Model outputs.
+            model_kwargs (dict): Current model kwargs.
+            is_encoder_decoder (bool): Whether using encoder-decoder architecture.
 
+        Returns:
+            dict: Updated model kwargs.
+        """
         # update cache
         if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor):
-            model_kwargs["cache"] = outputs[1]
             model_kwargs["past_key_values"] = outputs[1]
 
-        if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
-            model_kwargs["cache"] = outputs.past_key_values
+        if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
             model_kwargs["past_key_values"] = outputs.past_key_values
 
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None:
             token_type_ids = model_kwargs["token_type_ids"]
             model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1)
-
-        # update position_ids
-        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
-            position_ids = model_kwargs["position_ids"]
-            model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)
-
-        # update attention_mask
-        if not is_encoder_decoder and "attention_mask" in model_kwargs:
+        if not is_encoder_decoder and model_kwargs.get("attention_mask", None) is not None:
+            # update attention mask
             attention_mask = model_kwargs["attention_mask"]
-            # nn.Pad2D don't support the data type `bool`
-            if convert_dtype(attention_mask.dtype) == "bool":
-                attention_mask = paddle.cast(attention_mask, "int64")
-            if len(attention_mask.shape) == 4:
-                cur_device = paddle.get_device()
-                if cur_device.split(":")[0] == "npu":
-                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
-                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
-                else:
-                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
-                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False))(
-                        attention_mask
-                    )
-
-                dtype = convert_dtype(attention_mask.dtype)
-                if "int" in dtype:
-                    attention_mask[:, :, -1, -1] = 1
-                elif "float" in dtype:
-                    attention_mask[:, :, -1, -1] = 0.0
-                else:
-                    raise ValueError("The data type of input `attention_mask` must be bool, int or float")
-            else:
-                attention_mask = paddle.concat(
-                    [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype="int64")], axis=-1
-                )
-            model_kwargs["attention_mask"] = attention_mask
-
+            model_kwargs["attention_mask"] = paddle.concat(
+                [
+                    attention_mask,
+                    paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype),
+                ],
+                axis=-1,
+            )
         # update role_ids
         if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
             role_ids = model_kwargs["role_ids"]
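Note: the rewrite above drops the position_ids update and the 4-D Pad2D/NPU mask handling, so a 2-D attention mask now simply grows by one column per decode step, in its own dtype. A minimal standalone sketch of that behavior (not part of the commit; shapes are illustrative):

import paddle

# [batch_size, seq_len] mask after the prompt; dtype is preserved as-is
attention_mask = paddle.ones([2, 5], dtype="int64")

# one decode step: append a column of ones, as the new code does
attention_mask = paddle.concat(
    [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)],
    axis=-1,
)
print(attention_mask.shape)  # [2, 6] -- one new column per generated token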
@@ -611,11 +586,63 @@ def get_decoder_start_token_id(self, decoder_start_token_id=None, bos_token_id=N
                 "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
             )
 
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        # Implement in subclasses for custom behavior to prepare inputs in the
-        # generate method.
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        use_cache=True,
+        past_key_values=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        """Prepares model inputs for generation in PaddlePaddle models.
+
+        Args:
+            input_ids (paddle.Tensor):
+                The input token IDs with shape [batch_size, sequence_length].
+            use_cache (bool, optional):
+                Whether to use cached key-value states for faster generation.
+                Defaults to True.
+            past_key_values (Optional[Tuple[paddle.Tensor]]):
+                Cached past key-value states from previous generation steps.
+                If provided, the input_ids will be truncated to only keep the last token.
+            inputs_embeds (Optional[paddle.Tensor]):
+                Precomputed embeddings instead of token IDs.
+                Only used in the first generation step when past_key_values is None.
+            **kwargs:
+                Additional keyword arguments including:
+                - attention_mask (paddle.Tensor): Attention mask tensor
+
+        Returns:
+            Dict[str, Union[paddle.Tensor, bool, Dict]]:
+                A dictionary containing:
+                - "input_ids" or "inputs_embeds": The main input tensors
+                - "past_key_values": The cached key-value states
+                - "use_cache": Flag indicating whether to use caching
+                - "attention_mask": The attention mask tensor (if provided)
+                - "return_dict": Always set to True for consistent output format
+
+        """
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "return_dict": True,
+            }
+        )
 
-        return {"input_ids": input_ids}
+        return model_inputs
 
     def adjust_logits_during_generation(self, logits):
         # Implement in subclasses for custom behavior to adjust the logits in
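Note: the two hooks above form the per-step contract of the decode loop. A sketch of how a caller would wire them up (assumptions: `model` is a paddleformers causal LM exposing these hooks and returning a CausalLMOutputWithPast when return_dict=True; `greedy_decode_step` is an illustrative name, not from the commit):

import paddle

def greedy_decode_step(model, input_ids, model_kwargs):
    # build this step's inputs; past_key_values/attention_mask ride in model_kwargs
    inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = model(**inputs)  # return_dict=True, so outputs.logits is available
    next_token = paddle.argmax(outputs.logits[:, -1, :], axis=-1, keepdim=True)
    input_ids = paddle.concat([input_ids, next_token], axis=-1)
    # roll the cache and masks forward for the next step
    model_kwargs = model.update_model_kwargs_for_generation(outputs, model_kwargs)
    return input_ids, model_kwargs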

paddleformers/nn/__init__.py

Lines changed: 29 additions & 0 deletions
@@ -35,13 +35,40 @@
         "sft_postprocess_loss",
         "sft_loss_forward",
     ],
+    "moe.abstract": ["MOELayerBase"],
+    "moe.all_gather": ["allgather_async", "reduce_scatter_async", "AlltoAllSmart", "AllGatherAsync"],
+    "moe.all_to_all": ["AlltoAll", "AlltoAllAsync"],
+    "moe.moe_allgather_layer": ["ReshardCombineWeight", "MOEAllGatherLayerV2"],
+    "moe.moe_alltoall_layer": ["GateCombine", "combining"],
+    "moe.moe_block": ["create_moe_block", "MoEStatics"],
+    "moe.top_gate": [
+        "masked_fill",
+        "compute_optimal_transport",
+        "cast_if_needed",
+        "FusedGateDetachMatmul",
+        "gate_detach_matmul",
+        "TopKGate",
+    ],
+    "moe.utils": [
+        "ReduceScatterGroupOp",
+        "AllGatherGroupOp",
+        "get_async_loader",
+        "hack_offload_wait",
+        "all_gather_group",
+        "reduce_scatter_group",
+        "detach_and_requires_grad_",
+        "FakeClone",
+        "manual_backward",
+        "_parse_moe_group",
+    ],
     "activation": ["ACT2FN", "ClassInstantier", "ACT2CLS"],
     "embedding": ["Embedding"],
    "general": ["GeneralInterface"],
     "linear": ["Linear"],
     "lm_head": ["LMHead"],
     "mlp": ["MLP"],
     "norm": ["Norm", "LayerNorm", "RMSNorm"],
+    "pp_model": ["GeneralModelForCausalLMPipe"],
 }
 
 if TYPE_CHECKING:
@@ -53,7 +80,9 @@
     from .linear import *
     from .lm_head import *
     from .mlp import *
+    from .moe import *
     from .norm import *
+    from .pp_model import *
 else:
     sys.modules[__name__] = _LazyModule(
         __name__,
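Note: the new `import_structure` entries feed `_LazyModule`, so `import paddleformers.nn` stays cheap and the MoE submodules are only imported when one of their names is first accessed. A minimal sketch of that pattern (not the actual `_LazyModule` implementation) using a PEP 562 module-level `__getattr__`:

import importlib

# illustrative subset of the table above
_import_structure = {"moe.top_gate": ["TopKGate"], "pp_model": ["GeneralModelForCausalLMPipe"]}
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}

def __getattr__(name):
    # resolve e.g. TopKGate -> .moe.top_gate on first access, then delegate
    module = importlib.import_module("." + _name_to_module[name], __package__)
    return getattr(module, name)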

paddleformers/nn/criterion/interface.py

Lines changed: 12 additions & 8 deletions
@@ -21,12 +21,17 @@
 from ..general import GeneralInterface
 from .dpo_loss import dpo_loss_forward
 from .kto_loss import kto_loss_forward
-from .sft_loss import sft_loss_forward
+from .sft_loss import mtp_sft_loss_forward, sft_loss_forward
 
 
 class LossInterface(GeneralInterface):
 
-    _global_mapping = {"sft": sft_loss_forward, "dpo": dpo_loss_forward, "kto": kto_loss_forward}
+    _global_mapping = {
+        "sft": sft_loss_forward,
+        "dpo": dpo_loss_forward,
+        "kto": kto_loss_forward,
+        "mtp_sft": mtp_sft_loss_forward,
+    }
 
 
 ALL_LOSS_FUNCTIONS = LossInterface()
@@ -40,16 +45,12 @@ def __init__(self, config, return_tuple=True, ignore_eos_token=False, use_infohu
         self.kto_config = copy.deepcopy(config.get("kto_config", None))
         self.ignored_index = getattr(config, "ignored_index", -100)
         self.use_filtered_label_loss = config.get("use_filtered_label_loss", False)
-        self.loss_subbatch_seqlen = config.get(
-            "loss_subbatch_seqlen", -1
-        )  # splitting is controlled by loss_subbatch_seqlen; enabled only when loss_subbatch_seqlen > 0
+        self.loss_subbatch_seqlen = config.get("loss_subbatch_seqlen", -1)
         self.use_subbatch = self.loss_subbatch_seqlen > 0
         self.sequence_parallel = config.get("sequence_parallel", False)
         self.tensor_parallel = config.tensor_parallel_degree > 1
         self.use_fused_head_and_loss_fn = config.get("use_fused_head_and_loss_fn", False)
-        self.enable_parallel_cross_entropy = (
-            config.tensor_parallel_degree > 1 and config.tensor_parallel_output
-        )  # when loss is computed in parallel, use_fused_head_and_loss_fn = False
+        self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output
         logger.info(
             f"loss_subbatch_seqlen: {self.loss_subbatch_seqlen} , use_fused_head_and_loss_fn: {self.use_fused_head_and_loss_fn}, use_filtered_label_loss: {self.use_filtered_label_loss}"
         )
@@ -80,6 +81,9 @@ def __init__(self, config, return_tuple=True, ignore_eos_token=False, use_infohu
         else:
             loss_type = "sft"
 
+        if config.get("num_nextn_predict_layers", 0) > 0:
+            loss_type = "mtp_sft"
+
         self.loss_foward_fn = ALL_LOSS_FUNCTIONS.get(loss_type)
         self.loss_type = loss_type
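Note: with the registry change above, loss selection stays a plain name-to-callable lookup; the only new wrinkle is that `num_nextn_predict_layers > 0` reroutes "sft" to "mtp_sft" before the lookup. A sketch (the dict-style `config` literal is an illustrative stand-in for the model config):

from paddleformers.nn.criterion.interface import ALL_LOSS_FUNCTIONS

config = {"num_nextn_predict_layers": 1}  # e.g. ERNIE 4.5 MoE with one MTP head
loss_type = "mtp_sft" if config.get("num_nextn_predict_layers", 0) > 0 else "sft"
loss_forward = ALL_LOSS_FUNCTIONS.get(loss_type)  # -> mtp_sft_loss_forward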

paddleformers/nn/criterion/sft_loss.py

Lines changed: 57 additions & 4 deletions
@@ -61,7 +61,6 @@ def sft_loss_forward(
     logits, labels, hidden_states, lm_head_weight, lm_head_bias, transpose_y = sft_preprocess_inputs(
         self, logits, labels
     )
-
     if self.use_filtered_label_loss:
         if self.tensor_parallel and self.sequence_parallel and logits is None:
             masked_lm_labels, sparse_label_idx = sequence_parallel_sparse_mask_labels(labels, self.ignored_index)
@@ -70,21 +69,21 @@
             hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0)
             hidden_states = AllGatherVarlenOp.apply(hidden_states)
         else:
-            masked_lm_labels = masked_lm_labels.flatten()
+            masked_lm_labels = labels.flatten()
             sparse_label_idx = paddle.nonzero(masked_lm_labels != self.ignored_index).flatten()
             masked_lm_labels = paddle.take_along_axis(masked_lm_labels, sparse_label_idx, axis=0)
             if hidden_states is not None:
                 hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
                 hidden_states = paddle.take_along_axis(hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0)
             if logits is not None:
                 logits = paddle.gather(logits, sparse_label_idx, axis=1)
+        labels = masked_lm_labels
     else:
         if self.sequence_parallel:
             if hidden_states is not None:
                 hidden_states = AllGatherOp.apply(hidden_states)
 
-        masked_lm_labels = labels
-
+        masked_lm_labels = labels
     # bsz,seq_len,hidden_size or seq_len,hidden_size
     seq_len = masked_lm_labels.shape[1] if masked_lm_labels.ndim == 2 else masked_lm_labels.shape[0]
     if self.use_fused_head_and_loss_fn and self.use_subbatch and seq_len > self.loss_subbatch_seqlen:
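Note: the filtered-label path above keeps only the positions whose label is not `ignored_index`, shrinking both the labels and the hidden states before the head/loss computation. A standalone sketch of that gather (shapes are illustrative):

import paddle

ignored_index = -100
labels = paddle.to_tensor([-100, 5, -100, 7])              # flattened labels, shape [4]
idx = paddle.nonzero(labels != ignored_index).flatten()    # kept positions: [1, 3]
kept_labels = paddle.take_along_axis(labels, idx, axis=0)  # [5, 7]

hidden_states = paddle.rand([4, 8])                        # [seq_len, hidden_size]
kept_hidden = paddle.take_along_axis(hidden_states, idx.reshape([-1, 1]), axis=0)  # [2, 8]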
@@ -145,3 +144,57 @@ def sft_loss_forward(
         masked_lm_loss = self.loss_func(logits, labels.unsqueeze(-1))
     loss = sft_postprocess_loss(self, masked_lm_loss, labels, loss_mask, **kwargs)
     return loss
+
+
+def mtp_sft_loss_forward(
+    self: nn.Layer,
+    logits: Union[paddle.Tensor, Tuple[paddle.Tensor]],
+    labels: Union[paddle.Tensor, Tuple[paddle.Tensor]],
+    loss_mask: paddle.Tensor = None,
+    router_loss: paddle.Tensor = None,
+    mtp_logits: paddle.Tensor = None,
+    **kwargs
+):
+    num_nextn_predict_layers = self.config.get("num_nextn_predict_layers", 0)
+    multi_token_pred_lambda = self.config.get("multi_token_pred_lambda", 0.3)
+    if num_nextn_predict_layers > 0:
+        labels_ori = labels
+        labels = labels[:, :-num_nextn_predict_layers]
+        if loss_mask is not None:
+            loss_mask = loss_mask[:, :-num_nextn_predict_layers]
+        seq_length = labels.shape[1]
+
+    sft_loss = sft_loss_forward(self, logits, labels, loss_mask, **kwargs)
+
+    if num_nextn_predict_layers > 0:
+        mtp_loss_res = []
+        for depth in range(num_nextn_predict_layers):
+            logits_cur_depth = mtp_logits[depth]
+            labels_cur_depth = labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
+            res_cur_depth = sft_loss_forward(self, logits_cur_depth, labels_cur_depth, loss_mask)
+            mtp_loss_res.append(res_cur_depth)
+
+    def add_loss(main_loss, loss):
+        return main_loss + loss - loss.detach()
+
+    if self.return_tuple:
+        loss, loss_sum = sft_loss
+    else:
+        loss, loss_sum = sft_loss, None
+
+    if num_nextn_predict_layers > 0:
+        loss = add_loss(
+            loss,
+            multi_token_pred_lambda * sum([x[0] for x in mtp_loss_res]) / len(mtp_loss_res),
+        )
+
+        if loss_sum is not None:
+            loss_sum = loss_sum + multi_token_pred_lambda * sum([x[1].detach() for x in mtp_loss_res]) / len(mtp_loss_res)
+
+    if router_loss is not None and isinstance(router_loss, paddle.Tensor):
+        loss = loss + router_loss - router_loss.detach()
+
+    if self.return_tuple:
+        return loss, loss_sum
+    else:
+        return loss
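Note: `add_loss` and the router-loss line both use the same `x + aux - aux.detach()` trick: the scalar value of the loss is unchanged (the aux terms cancel numerically), but the auxiliary term still contributes gradients, since detach() only blocks backprop. A small self-contained check of that property:

import paddle

w = paddle.to_tensor([2.0], stop_gradient=False)
main = (3.0 * w).sum()  # value 6.0, d/dw = 3
aux = (w * w).sum()     # value 4.0, d/dw = 2w = 4
loss = main + aux - aux.detach()
print(float(loss))      # 6.0 -- reported value is main alone
loss.backward()
print(float(w.grad))    # 7.0 -- gradients flow from both main and aux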
