Commit e900f1d

Fix attn impl and ernie4.5 for erniekit (#2580)
1 parent b310142 commit e900f1d

10 files changed: +95 additions, -8 deletions


examples/run_finetune.py

Lines changed: 1 addition & 0 deletions

@@ -171,6 +171,7 @@ def main():
     model_config.seq_length = training_args.max_seq_len
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
+    model_config._attn_implementation = model_args.attn_impl
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")

paddleformers/nn/attention/flashmask_attention.py

Lines changed: 3 additions & 0 deletions

@@ -34,6 +34,9 @@ def flashmask_attention_forward(
     **kwargs
 ):
     # b,l,h,d
+    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.ndim == 3:
+        attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
+
     if sink is None:
         out = flashmask_attention(
             query,
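
The ndim check above (and the matching one in sdpa_attention.py below) normalizes a 3-D startend-row-indices tensor to 4-D by appending a trailing axis. A minimal sketch of the shape change, with hypothetical tensor sizes:

    import paddle

    # Hypothetical sizes: [batch, num_heads, seq_len] -> [batch, num_heads, seq_len, 1].
    indices = paddle.zeros([2, 8, 128], dtype="int32")
    if indices.ndim == 3:
        indices = indices.unsqueeze(-1)
    print(indices.shape)  # [2, 8, 128, 1]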

paddleformers/nn/attention/sdpa_attention.py

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ def sdpa_attention_forward(
         is_causal = query.shape[1] > 1 and attention_mask is None and getattr(module, "is_causal", True)
     elif attn_mask_startend_row_indices is not None:
         is_causal = False
+        if attn_mask_startend_row_indices.ndim == 3:
+            attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
         attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_startend_row_indices, query.dtype)

     if sink is None:

paddleformers/nn/criterion/sft_loss.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def sft_preprocess_inputs(self, logits, labels):


 def sft_postprocess_loss(self, masked_lm_loss, labels, loss_mask, **kwargs):
-    if loss_mask is None:
+    if self.use_filtered_label_loss or loss_mask is None:
         loss_mask = labels != self.ignored_index
     loss_mask = loss_mask.reshape([-1]).cast(paddle.float32)
     # align position-wise, aggregate in full precision
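
For context, a toy check of the mask the rewritten branch rebuilds when it ignores the incoming loss_mask; the ignored_index value of -100 is an assumption for illustration:

    import paddle

    ignored_index = -100  # assumed sentinel for padded/ignored label positions
    labels = paddle.to_tensor([[2, 7, -100, -100]])
    loss_mask = (labels != ignored_index).reshape([-1]).cast(paddle.float32)
    print(loss_mask.numpy())  # [1. 1. 0. 0.]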

paddleformers/nn/pp_model.py

Lines changed: 15 additions & 0 deletions

@@ -653,6 +653,14 @@ def _prepare_pipeline_inputs_func(cls, inputs):
                 "position_ids",
                 "nbatch_pack_offset",
             ]
+        # (NOTE) attn_mask_start_row_indices is special for erniekit
+        elif "attn_mask_start_row_indices" in inputs:
+            first_stage_keys = [
+                "input_ids",
+                "attn_mask_start_row_indices",
+                "position_ids",
+                "nbatch_pack_offset",
+            ]
         else:  # inputs is list
             if "attention_mask" in inputs[0]:
                 first_stage_keys = [
@@ -661,6 +669,13 @@ def _prepare_pipeline_inputs_func(cls, inputs):
                     "position_ids",
                     "nbatch_pack_offset",
                 ]
+            elif "attn_mask_start_row_indices" in inputs[0]:
+                first_stage_keys = [
+                    "input_ids",
+                    "attn_mask_start_row_indices",
+                    "position_ids",
+                    "nbatch_pack_offset",
+                ]
         last_stage_keys = ["labels", "loss_mask"]

         def get_expected_keys(inputs, keys):
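
The new elif branches only change which keys are forwarded to the first pipeline stage when the batch carries the erniekit-style attn_mask_start_row_indices. A minimal sketch of the key filtering, with a placeholder batch dict (the values are dummies, and the comprehension stands in for the real get_expected_keys helper):

    first_stage_keys = ["input_ids", "attn_mask_start_row_indices", "position_ids", "nbatch_pack_offset"]
    last_stage_keys = ["labels", "loss_mask"]

    batch = {
        "input_ids": [[1, 2, 3, 4]],
        "attn_mask_start_row_indices": [[4, 4, 4, 4]],
        "position_ids": [[0, 1, 2, 3]],
        "nbatch_pack_offset": [[0]],
        "labels": [[2, 3, 4, -100]],
        "loss_mask": [[1, 1, 1, 0]],
    }

    # Only the listed keys reach each stage.
    first_stage_inputs = tuple(batch[k] for k in first_stage_keys if k in batch)
    last_stage_inputs = tuple(batch[k] for k in last_stage_keys if k in batch)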

paddleformers/transformers/auto/configuration.py

Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,8 @@
         ("bert", "Bert"),
         ("deepseek_v2", "DeepseekV2"),
         ("deepseek_v3", "DeepseekV3"),
+        ("ernie4_5", "Ernie4_5"),
+        ("ernie4_5_moe", "Ernie4_5_Moe"),
         ("llama", "Llama"),
         ("qwen", "QWen"),
         ("qwen2", "Qwen2"),

paddleformers/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@
         ("DeepseekV2", "deepseek_v2"),
         ("DeepseekV3", "deepseek_v3"),
         ("Ernie4_5", "ernie4_5"),
+        ("Ernie4_5_Moe", "ernie4_5_moe"),
         ("Llama", "llama"),
         ("QWen", "qwen"),
         ("Qwen2", "qwen2"),

paddleformers/transformers/ernie4_5/modeling.py

Lines changed: 3 additions & 0 deletions

@@ -770,6 +770,9 @@ def forward(
         Returns:
             Union[tuple, CausalLMOutputWithCrossAttentions]: Model outputs.
         """
+        if kwargs.get("attn_mask_start_row_indices", None) is not None and attn_mask_startend_row_indices is None:
+            attn_mask_startend_row_indices = kwargs.pop("attn_mask_start_row_indices")
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
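
A toy illustration of the aliasing above: erniekit passes the legacy keyword attn_mask_start_row_indices, which is mapped onto attn_mask_startend_row_indices when the latter is not supplied (forward_stub is a hypothetical stand-in for the model's forward):

    def forward_stub(attn_mask_startend_row_indices=None, **kwargs):
        if kwargs.get("attn_mask_start_row_indices", None) is not None and attn_mask_startend_row_indices is None:
            attn_mask_startend_row_indices = kwargs.pop("attn_mask_start_row_indices")
        return attn_mask_startend_row_indices

    print(forward_stub(attn_mask_start_row_indices="legacy-indices"))  # legacy-indices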

paddleformers/transformers/ernie4_5_moe/configuration.py

Lines changed: 46 additions & 1 deletion

@@ -13,6 +13,7 @@
 # limitations under the License.

 """ Ernie4_5_Moe model configuration """
+import json
 from typing import Optional, Union

 from ...utils.log import logger
@@ -233,7 +234,7 @@ def __init__(
         self.multi_token_pred_lambda = multi_token_pred_lambda
         self.enable_mtp_magic_send = enable_mtp_magic_send
         self.use_recompute_mtp = use_recompute_mtp
-
+        self.dpo_config = dpo_config
         self.register_unsavable_keys(
             [
                 "disable_ffn_model_parallel",
@@ -275,5 +276,49 @@ def __init__(
             ]
         )

+    def to_json_string(self, use_diff: bool = True, saving_file=False) -> str:
+        """
+        Serialize the configuration to a JSON string with special handling for non-serializable objects.
+
+        This method overrides the default JSON serialization to handle special objects like
+        paddle.distributed.communication.group.Group that cannot be serialized normally.
+
+        Args:
+            use_diff (bool, optional): If True, only outputs the differences from the default configuration.
+                If False, outputs the full configuration. Defaults to True.
+
+        Returns:
+            str: A JSON formatted string representation of the configuration, with proper indentation
+                and handling for non-serializable objects.
+        """
+        if use_diff is True:
+            config_dict = self.to_diff_dict(saving_file=saving_file)
+        else:
+            config_dict = self.to_dict(saving_file=saving_file)
+
+        def _serializer(obj):
+            """
+            Handle non-serializable objects during JSON conversion.
+
+            Args:
+                obj: The object to be serialized
+
+            Returns:
+                The serializable representation of the object
+            """
+            return repr(obj)
+
+        return (
+            json.dumps(
+                config_dict,
+                indent=2,
+                sort_keys=True,
+                ensure_ascii=False,
+                default=_serializer,
+            )
+            + "\n"
+        )

 __all__ = ["Ernie4_5_MoeConfig"]
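
The override matters because a config may hold process-group handles that json.dumps cannot encode. A standalone sketch of the default=... fallback it relies on (the Group class here is a dummy stand-in for paddle.distributed's group object):

    import json

    class Group:
        # Dummy stand-in for a non-JSON-serializable object such as a distributed process group.
        def __repr__(self):
            return "<Group id=0 ranks=[0, 1]>"

    config_dict = {"hidden_size": 1024, "moe_group": Group()}
    # Without default=..., json.dumps raises TypeError on the Group instance;
    # with default=repr it is stored as its repr() string instead.
    print(json.dumps(config_dict, indent=2, sort_keys=True, ensure_ascii=False, default=repr))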

paddleformers/transformers/ernie4_5_moe/modeling.py

Lines changed: 21 additions & 6 deletions

@@ -443,7 +443,17 @@ class Ernie4_5_MoePretrainedModel(PretrainedModel):
     config_class = Ernie4_5_MoeConfig
     base_model_prefix = "model"
     _keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]
-    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "gate"]
+    transpose_weight_keys = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "gate",
+        "mtp_linear_proj.0",
+    ]

     @classmethod
     def _get_tensor_parallel_mappings(cls, config, is_split=True):
@@ -659,16 +669,18 @@ def __init__(self, config: Ernie4_5_MoeConfig):
         self.mtp_linear_proj = paddle.nn.LayerList(
             [
                 GeneralLinear.create(
-                    self.config.hidden_size * 2,
-                    self.config.hidden_size,
+                    config.hidden_size * 2,
+                    config.hidden_size,
                     has_bias=config.use_bias,
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
+                    linear_type="default",
                 )
-                for _ in range(self.config.num_nextn_predict_layers)
+                for _ in range(config.num_nextn_predict_layers)
             ]
         )
         if config.sequence_parallel:
+            logger.info("enable sequence parallel for mtp_linear")
             for mtp_linear in self.mtp_linear_proj:
                 mark_as_sequence_parallel_parameter(mtp_linear.weight)
             if config.use_bias:
@@ -795,7 +807,7 @@ def forward(
                 attention_mask, inputs_embeds.shape[:2], kv_seq_len, inputs_embeds.dtype
             )

-        if self.config.num_nextn_predict_layers > 0:
+        if self.training and self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]
             inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
             inputs_embeds_ori = inputs_embeds
@@ -896,7 +908,7 @@ def forward(
             all_gate_logits = all_gate_logits + (gate_logits,)

         # Multi Token Prediction
-        if self.config.num_nextn_predict_layers > 0:
+        if self.training and self.config.num_nextn_predict_layers > 0:
             mtp_outputs.append(hidden_states)

             for depth in range(self.config.num_nextn_predict_layers):
@@ -1088,6 +1100,9 @@ def forward(
         Returns:
             Union[tuple, MoECausalLMOutputWithPast]: Model outputs.
         """
+        if kwargs.get("attn_mask_start_row_indices", None) is not None and attn_mask_startend_row_indices is None:
+            attn_mask_startend_row_indices = kwargs["attn_mask_start_row_indices"]
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
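
With the self.training gate added, the multi-token-prediction input split only runs during training. A toy illustration of that split, with hypothetical tensor sizes (variable names mirror the diff, shapes are made up):

    import paddle

    num_nextn_predict_layers = 1
    inputs_embeds = paddle.randn([2, 8, 16])  # [batch, seq_len, hidden]

    training = True
    if training and num_nextn_predict_layers > 0:
        # Trailing positions feed the next-token-prediction head(s); the rest feed the main trunk.
        inputs_embeds_extra = inputs_embeds[:, -num_nextn_predict_layers:, :]  # shape [2, 1, 16]
        inputs_embeds = inputs_embeds[:, :-num_nextn_predict_layers, :]        # shape [2, 7, 16]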
