
Commit e4d3a9a

Merge branch 'PaddlePaddle:develop' into dev_20250108_fix_requirements

2 parents: c528162 + fb60645

File tree

13 files changed: +2753 additions, -12 deletions

paddlenlp/trainer/auto_trainer.py

Lines changed: 9 additions & 1 deletion
@@ -105,7 +105,15 @@ def loss_func(loss, outputs):
         model = kwargs["model"]
         for param in model.parameters():
             if not param._is_initialized():
-                param.initialize()
+                try:
+                    param.initialize()
+                except Exception as e:
+                    # NOTE(zhangwl): the param may be left uninitialized on purpose because its init func is set later; users need to set the init func before constructing AutoTrainer.
+                    logger.warning(
+                        f"AutoTrainer requires all parameters to be initialized when auto_trainer init, but failed to initialize parameter {param.name} {param}.\n"
+                        + "Please check param init func.\n"
+                        + f"The original exception message is:\n{str(e)}"
+                    )
         kwargs["model"] = model

         super().__init__(*args, **kwargs)
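For context, a minimal sketch (not part of the commit) of the deferred-initialization scenario this guard tolerates, assuming Paddle's lazy-initialization API (paddle.LazyGuard, param.initialize()); the layer and logger names below are illustrative only:

import logging

import paddle
from paddle import nn

logger = logging.getLogger(__name__)

with paddle.LazyGuard():
    layer = nn.Linear(4, 4)  # parameters are created lazily, not yet materialized

for param in layer.parameters():
    try:
        param.initialize()  # materialize the parameter now
    except Exception as e:
        # Mirrors the commit's behavior: warn instead of failing hard, since the
        # parameter's init func may only be registered later by the user.
        logger.warning(f"failed to initialize {param.name}: {e}")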

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 89 additions & 11 deletions
@@ -41,7 +41,13 @@

 class CheckpointConverter:
     def __init__(
-        self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
+        self,
+        hybrid_parallel_ckpt_path,
+        state_dict,
+        parameter_to_structured_name,
+        trainging_args=None,
+        patch_dict=None,
+        local_view_pattern: list | bool = None,
     ):
         self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
         self.path = hybrid_parallel_ckpt_path
@@ -85,6 +91,17 @@ def __init__(
                 self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
             for k in del_keys:
                 self.auto_parallel_state_dict.pop(k)
+        # solve the problem of inconsistent parameter names in moe automatic parallel mode.
+        if hasattr(trainging_args, "moe_group") and trainging_args.moe_group:
+            if local_view_pattern is False:
+                self.local_view_pattern_list = None
+            else:
+                if isinstance(local_view_pattern, list):
+                    self.local_view_pattern_list = local_view_pattern
+                else:
+                    self.local_view_pattern_list = ["experts"]
+        else:
+            self.local_view_pattern_list = None

         flags = [
             ["tp degree", self.tp_degree],
@@ -497,6 +514,46 @@ def gen_metadata_and_prepare_source_state_dict(self):
         else:
             return self.gen_metadata_for_tp_sharded_tensor()

+    def rename_local_view_state_dict(self, state_dict, file_name):
+        """
+        Rename the keys for local views to the keys for global views, and return the renamed `state_dict`.
+        """
+        if self.local_view_pattern_list is None:
+            return state_dict
+        # case 1: moe_group is mp_group
+        if self.tp_degree > 1 and self.sharding_degree <= 1:
+            (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
+            expert_name_old2new = {}
+            for pattern in self.local_view_pattern_list:
+                expert_pattern = rf"({pattern}\.)(\d+)"
+                # extract all expert IDs
+                expert_ids = set()
+                for state_name in state_dict.keys():
+                    res = re.search(expert_pattern, state_name)
+                    if res:
+                        expert_ids.add(int(res.group(2)))
+                expert_num = len(expert_ids)
+                # construct the old-name to new-name mapping
+                for state_name in state_dict.keys():
+                    res = re.search(expert_pattern, state_name)
+                    if res:
+                        new_expert_id = int(res.group(2)) % expert_num + tp_rank * expert_num
+                        expert_name_old2new[state_name] = re.sub(
+                            expert_pattern, f"{res.group(1)}{new_expert_id}", state_name
+                        )
+            # rename state_dict
+            renamed_state_dict = {
+                expert_name_old2new[state_name]
+                if state_name in expert_name_old2new
+                else state_name: state_dict[state_name]
+                for state_name in state_dict.keys()
+            }
+
+            return renamed_state_dict
+        # TODO: add support for sharding
+        else:
+            return state_dict
+
     def load_state_dict_and_rename(self):
         """
         Parse the distributed information from the names of the checkpoint files and evenly parse out the distributed information for each weight/optimizer state
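For illustration, a standalone sketch (not part of the commit) of the renaming rule that rename_local_view_state_dict applies: with expert_num experts visible in a rank's local state dict, local expert index i becomes global index i % expert_num + tp_rank * expert_num. The helper name and keys below are hypothetical.

import re

def rename_experts(state_dict, tp_rank, pattern="experts"):
    expert_pattern = rf"({pattern}\.)(\d+)"
    expert_ids = {int(m.group(2)) for k in state_dict if (m := re.search(expert_pattern, k))}
    expert_num = len(expert_ids)
    renamed = {}
    for name, value in state_dict.items():
        m = re.search(expert_pattern, name)
        if m:
            new_id = int(m.group(2)) % expert_num + tp_rank * expert_num
            name = re.sub(expert_pattern, f"{m.group(1)}{new_id}", name)
        renamed[name] = value
    return renamed

# With 2 local experts, tp_rank 1 maps experts.0 -> experts.2 and experts.1 -> experts.3:
print(rename_experts({"moe.experts.0.w": 0, "moe.experts.1.w": 1}, tp_rank=1))
# {'moe.experts.2.w': 0, 'moe.experts.3.w': 1}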
@@ -741,11 +798,10 @@ def load_state_dict_and_rename(self):
                 model_state_file_name = self.get_model_state_file_from(file_name)
                 assert model_state_file_name is not None
                 model_state_keys = global_file_to_state_dict_keys_mapping[model_state_file_name]
-                renamed_state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
-                self.get_sharded_tensor_infos(file, renamed_state_dict, cur_rank_sharded_tensor_infos)
-                self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
-            else:
-                self.get_sharded_tensor_infos(file_name, state_dict, cur_rank_sharded_tensor_infos)
+                state_dict = self.rename_using_optimizer_state_order(model_state_keys, state_dict)
+                renamed_state_dict = self.rename_local_view_state_dict(state_dict, file_name)
+                self.get_sharded_tensor_infos(file_name, renamed_state_dict, cur_rank_sharded_tensor_infos)
+                self.cur_rank_loaded_state_dict[file_name] = renamed_state_dict
         else:
             for file, state_dict in self.cur_rank_loaded_state_dict.items():
                 # The rule for renaming is to change the master_weights name in the optimizer state to the model weight name,
@@ -897,6 +953,9 @@ def rename(old_name, parameter_to_structured_name):
             return None

         for key, value in state_dict.items():
+            # NOTE: Skip parameters that are not initialized, i.e. those that do not live on the current rank.
+            if value is None or (isinstance(value, paddle.Tensor) and not value._is_initialized()):
+                continue
             if key in parameter_to_structured_name.values():
                 new_name = key
             else:
@@ -909,7 +968,9 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
     def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         name_mapping = {}
         suffix_bucket = {}
-        assert len(optimizer_state_dict) % len(model_state_keys) == 0
+        # TODO: After adapting to sharding, remove the code below.
+        if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+            assert len(optimizer_state_dict) % len(model_state_keys) == 0
         for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
             suffix_bucket[suffix] = []
         for opt_name, opt_value in optimizer_state_dict.items():
@@ -927,10 +988,27 @@ def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):
         for suffix, old_names in suffix_bucket.items():
             if len(old_names) == 0:
                 continue
-            assert len(old_names) == len(model_state_keys)
-            for i in range(len(old_names)):
-                name_mapping[old_names[i]] = model_state_keys[i] + suffix
-
+            # TODO: After adapting to sharding, remove the code below.
+            if self.is_sharding_stage3 or (self.sharding_degree > 1 and self.sharding_stage1_v == 2):
+                assert len(old_names) == len(model_state_keys)
+
+            # NOTE: Handle the case where the number of master_weight entries is not equal to the number of model_state_keys.
+            if suffix != ".master_weight":
+                for i in range(len(old_names)):
+                    name_mapping[old_names[i]] = model_state_keys[i] + suffix
+            else:
+                for i in range(len(old_names)):
+                    param = old_names[i][:-14]
+                    index = -1
+                    for idx, opt_name in enumerate(suffix_bucket[".moment1"]):
+                        if param == opt_name[:-24]:
+                            index = idx
+                            break
+                    if index >= 0:
+                        name_mapping[old_names[i]] = model_state_keys[index] + suffix
+                    else:
+                        raise RuntimeError(f"Can't find {param} in optimizer state dict.")
+        # rename state dict
         renamed_state_dict = {}
         for k, v in optimizer_state_dict.items():
             renamed_state_dict[name_mapping[k]] = v
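For illustration, a simplified sketch (not part of the commit) of the ordering-based mapping that rename_using_optimizer_state_order relies on: optimizer states are bucketed by suffix and matched to model keys positionally. The suffixes and names below are hypothetical stand-ins for the real OPTIMIZER_STATE_NAME_SUFFIX entries.

model_state_keys = ["embedding.weight", "linear.weight"]
optimizer_state_dict = {
    "param_0.moment1": 0, "param_1.moment1": 1,
    "param_0.moment2": 2, "param_1.moment2": 3,
}

# group optimizer state names by suffix, preserving their original order
suffix_bucket = {".moment1": [], ".moment2": []}
for opt_name in optimizer_state_dict:
    for suffix in suffix_bucket:
        if opt_name.endswith(suffix):
            suffix_bucket[suffix].append(opt_name)

# map the i-th name in each bucket to the i-th model key plus that suffix
name_mapping = {}
for suffix, old_names in suffix_bucket.items():
    for i, old_name in enumerate(old_names):
        name_mapping[old_name] = model_state_keys[i] + suffix

print(name_mapping["param_1.moment2"])  # linear.weight.moment2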

paddlenlp/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -209,6 +209,9 @@
 from .xlm.modeling import *
 from .xlm.tokenizer import *
 from .xlm.configuration import *
+from .xlm_roberta.modeling import *
+from .xlm_roberta.tokenizer import *
+from .xlm_roberta.configuration import *
 from .gau_alpha.modeling import *
 from .gau_alpha.tokenizer import *
 from .gau_alpha.configuration import *

paddlenlp/transformers/auto/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -113,6 +113,7 @@
         ("unimo", "UNIMOConfig"),
         ("visualglm", "VisualGLMConfig"),
         ("xlm", "XLMConfig"),
+        ("xlm-roberta", "XLMRobertaConfig"),
         ("xlnet", "XLNetConfig"),
         ("yuan", "YuanConfig"),
     ]
@@ -202,6 +203,7 @@
         ("unimo", "UNIMO"),
         ("visualglm", "VisualGLM"),
         ("xlm", "XLM"),
+        ("xlm-roberta", "XLMRoberta"),
         ("xlnet", "XLNet"),
         ("yuan", "Yuan"),
     ]

paddlenlp/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@
         ("UNIMO", "unimo"),
         ("XLNet", "xlnet"),
         ("XLM", "xlm"),
+        ("XLMRoberta", "xlm_roberta"),
         ("GPT", "gpt"),
         ("GLM", "glm"),
         ("MT5", "mt5"),

paddlenlp/transformers/auto/tokenizer.py

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@
         ("squeezebert", "SqueezeBertTokenizer"),
         ("t5", "T5Tokenizer"),
         ("xlm", "XLMTokenizer"),
+        ("xlm_roberta", "XLMRobertaTokenizer"),
         ("xlnet", "XLNetTokenizer"),
         ("bert_japanese", "BertJapaneseTokenizer"),
         ("bigbird", "BigBirdTokenizer"),
paddlenlp/transformers/xlm_roberta/__init__.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration import *
+from .modeling import *
+from .tokenizer import *
paddlenlp/transformers/xlm_roberta/configuration.py

Lines changed: 160 additions & 0 deletions

@@ -0,0 +1,160 @@
+# coding=utf-8
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" XLM-RoBERTa configuration"""
+
+from ..model_utils import PretrainedConfig
+
+__all__ = ["PRETRAINED_INIT_CONFIGURATION", "XLMRobertaConfig"]
+
+PRETRAINED_INIT_CONFIGURATION = {
+    "hf-internal-testing/tiny-random-onnx-xlm-roberta": {
+        "attention_probs_dropout_prob": 0.1,
+        "bos_token_id": 0,
+        "classifier_dropout": None,
+        "eos_token_id": 2,
+        "hidden_act": "gelu",
+        "hidden_dropout_prob": 0.1,
+        "hidden_size": 4,
+        "initializer_range": 0.02,
+        "intermediate_size": 37,
+        "layer_norm_eps": 1e-05,
+        "max_position_embeddings": 514,
+        "model_type": "xlm-roberta",
+        "num_attention_heads": 4,
+        "num_hidden_layers": 5,
+        "output_past": True,
+        "pad_token_id": 1,
+        "position_embedding_type": "absolute",
+        "dtype": "float32",
+        "type_vocab_size": 1,
+        "use_cache": True,
+        "vocab_size": 250002,
+    },
+}
+
+
+class XLMRobertaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`XLMRobertaModel`] or a [`TFXLMRobertaModel`]. It
+    is used to instantiate an XLM-RoBERTa model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the XLMRoBERTa
+    [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`XLMRobertaModel`] or [`TFXLMRobertaModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`XLMRobertaModel`] or
+            [`TFXLMRobertaModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+
+    Examples:
+
+    ```python
+    >>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel
+
+    >>> # Initializing an XLM-RoBERTa xlm-roberta-base style configuration
+    >>> configuration = XLMRobertaConfig()
+
+    >>> # Initializing a model (with random weights) from the xlm-roberta-base style configuration
+    >>> model = XLMRobertaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "xlm-roberta"
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        position_embedding_type="absolute",
+        use_cache=True,
+        classifier_dropout=None,
+        **kwargs,
+    ):
+        kwargs["return_dict"] = kwargs.pop("return_dict", False)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
