
Commit b2a24c6

[New features] add roberta & gpt conversion (#4407)

* add roberta & gpt conversion
* update gpt model
* revert roberta related files
* update gpt loading
* update requirements
* fix input_ids

1 parent d9cd8c3 commit b2a24c6

File tree

9 files changed: +241, -27 lines changed

.github/workflows/tests.yml
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
+        pip install -r tests/requirements.txt
         make install
     - name: run the command
       run: make test

paddlenlp/transformers/conversion_utils.py
Lines changed: 25 additions & 11 deletions

@@ -255,7 +255,7 @@ def should_merge_last_two_dim(self) -> bool:
         """check that wether merge last two dim"""
         return self.action == "merge_last_two_dim"

-    def run(self, tensor: ndarray) -> ndarray:
+    def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
         """run some custom operation on ndarray, eg: transpose, merge_last_two_dim

         Args:

@@ -264,12 +264,21 @@ def run(self, tensor: ndarray) -> ndarray:
         Returns:
             ndarray: the final tensor
         """
+        tensor = state_dict.pop(name)
         if self.action == "transpose":
             return transpose(tensor, [1, 0])
         if self.action == "merge_last_two_dim":
             shape = tensor.shape
             assert len(shape) == 3
             return np.reshape(tensor, [shape[0], -1])
+        if self.action == "split":
+            assert self.index is not None, "when action is `split`, index field is required."
+
+            if self.index < 2:
+                state_dict[name] = tensor
+            # qkv is stored in same tensor, so it should be split into 3 arr
+            tensors = np.split(tensor, 3, axis=-1)
+            return tensors[self.index]
         return tensor

     def matched(self, text: str) -> bool:

@@ -490,6 +499,9 @@ class LogitComparer:
     config_fields_to_be_removed: List[str] = ["transformers_version"]
     architectures: Dict[str, Type[PretrainedModel]] = {}

+    def __init__(self, input_dir: str) -> None:
+        self.input_dir = input_dir
+
     def get_paddle_pytorch_model_classes(self) -> Tuple[object, object]:
         """return the [PaddleModelClass, PytorchModelClass] to
         1. generate paddle model automatically

@@ -574,13 +586,15 @@ def compare_model_state_dicts(
         for name_mapping in name_mappings:
             model_state_saver.add(name_mapping.target_name, "pytorch_key", name_mapping.source_name)

-            paddle_numpy = paddle_state_dict.pop(name_mapping.target_name)
-            model_state_saver.add(name_mapping.target_name, "paddle", paddle_numpy)
-            model_state_saver.add(name_mapping.target_name, "paddle-shape", str(paddle_numpy.shape))
+            if name_mapping.target_name in paddle_state_dict:
+                paddle_numpy = paddle_state_dict.pop(name_mapping.target_name)
+                model_state_saver.add(name_mapping.target_name, "paddle", paddle_numpy)
+                model_state_saver.add(name_mapping.target_name, "paddle-shape", str(paddle_numpy.shape))

-            pytorch_numpy = pytorch_state_dict.pop(name_mapping.source_name)
-            model_state_saver.add(name_mapping.target_name, "pytorch", pytorch_numpy)
-            model_state_saver.add(name_mapping.target_name, "pytorch-shape", str(pytorch_numpy.shape))
+            if name_mapping.source_name in pytorch_state_dict:
+                pytorch_numpy = pytorch_state_dict.pop(name_mapping.source_name)
+                model_state_saver.add(name_mapping.target_name, "pytorch", pytorch_numpy)
+                model_state_saver.add(name_mapping.target_name, "pytorch-shape", str(pytorch_numpy.shape))

         model_state_saver.summary()

@@ -594,8 +608,7 @@ def compare_logits(self) -> bool:
         paddle_model = PaddleModel.from_pretrained(self.input_dir)

         # 0. init the name_mapping & tensor_info_saver & logit_hooker
-        num_layers = self.get_num_layer(list(paddle_model.state_dict().keys()))
-        name_mappings = self.get_name_mapping(num_layers, paddle_model.config["architectures"])
+        name_mappings = self.get_name_mapping(paddle_model.config)
         tensor_info_saver = TensorInfoSaver()

         logit_hooker = LogitHooker(name_mappings, tensor_info_saver)

@@ -707,8 +720,9 @@ def convert(cls, weight_file: str, config: PretrainedConfig, cache_dir: str) ->
                 logger.warning(f"key<{name_mapping.source_name}> not in the pytorch weight file.")
                 continue

-            state_dict[name_mapping.target_name] = name_mapping.run(state_dict.pop(name_mapping.source_name))
-            all_layer_names.remove(name_mapping.source_name)
+            state_dict[name_mapping.target_name] = name_mapping.run(state_dict, name_mapping.source_name)
+            if name_mapping.source_name in all_layer_names:
+                all_layer_names.remove(name_mapping.source_name)

         if all_layer_names:
             logger.warning(f"there are {len(all_layer_names)} tensors not initialized:")

paddlenlp/transformers/gpt/configuration.py
Lines changed: 12 additions & 1 deletion

@@ -240,7 +240,18 @@ class GPTConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""
     model_type = "gpt"
-    attribute_map: Dict[str, str] = {"num_classes": "num_labels", "dropout": "classifier_dropout"}
+    attribute_map: Dict[str, str] = {
+        "num_classes": "num_labels",
+        "dropout": "classifier_dropout",
+        "n_positions": "max_position_embeddings",
+        "n_embd": "hidden_size",
+        "n_layer": "num_hidden_layers",
+        "n_head": "num_attention_heads",
+        "n_inner": "intermediate_size",
+        "activation_function": "hidden_act",
+        "resid_pdrop": "attention_probs_dropout_prob",
+    }
+
     pretrained_init_configuration = GPT_PRETRAINED_INIT_CONFIGURATION

     def __init__(
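
The extended attribute_map lets a GPTConfig answer to the Hugging Face GPT-2 field names (n_embd, n_layer, n_head, ...) while storing the canonical PaddleNLP fields. Below is a minimal standalone sketch of that aliasing pattern, assuming it mirrors the Hugging Face PretrainedConfig mechanism where reads and writes on an alias are redirected to the canonical attribute; TinyConfig is a toy class for illustration, not the real GPTConfig.

    class TinyConfig:
        # alias name -> canonical attribute actually stored on the instance
        attribute_map = {"n_embd": "hidden_size", "n_layer": "num_hidden_layers"}

        def __init__(self, hidden_size=768, num_hidden_layers=12):
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers

        def __getattr__(self, name):
            # only reached when normal lookup fails, i.e. for alias names
            aliases = type(self).attribute_map
            if name in aliases:
                return getattr(self, aliases[name])
            raise AttributeError(name)

        def __setattr__(self, name, value):
            # writes through an alias land on the canonical attribute
            name = type(self).attribute_map.get(name, name)
            super().__setattr__(name, value)

    cfg = TinyConfig()
    cfg.n_layer = 24  # actually updates num_hidden_layers
    assert cfg.n_embd == 768 and cfg.num_hidden_layers == 24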

paddlenlp/transformers/gpt/modeling.py
Lines changed: 75 additions & 0 deletions

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import collections

@@ -24,6 +25,7 @@
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from paddle.nn.layer.transformer import _convert_param_attr_to_list

+from ...utils.converter import StateDictNameMapping
 from ...utils.log import logger
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (

@@ -460,6 +462,79 @@ class GPTPretrainedModel(PretrainedModel):
     base_model_prefix = "gpt"
     config_class = GPTConfig

+    @classmethod
+    def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:
+        mappings: list[StateDictNameMapping] = []
+        model_mappings = [
+            ["wte.weight", "embeddings.word_embeddings.weight"],
+            ["wpe.weight", "embeddings.position_embeddings.weight"],
+            ["ln_f.weight", "decoder.norm.weight"],
+            ["ln_f.bias", "decoder.norm.bias"],
+        ]
+        for layer_index in range(config.num_hidden_layers):
+            layer_mappings = [
+                [f"h.{layer_index}.ln_1.weight", f"decoder.layers.{layer_index}.norm1.weight"],
+                [f"h.{layer_index}.ln_1.bias", f"decoder.layers.{layer_index}.norm1.bias"],
+                [f"h.{layer_index}.ln_2.weight", f"decoder.layers.{layer_index}.norm2.weight"],
+                [f"h.{layer_index}.ln_2.bias", f"decoder.layers.{layer_index}.norm2.bias"],
+                [f"h.{layer_index}.mlp.c_fc.weight", f"decoder.layers.{layer_index}.linear1.weight"],
+                [f"h.{layer_index}.mlp.c_fc.bias", f"decoder.layers.{layer_index}.linear1.bias"],
+                [f"h.{layer_index}.mlp.c_proj.weight", f"decoder.layers.{layer_index}.linear2.weight"],
+                [f"h.{layer_index}.mlp.c_proj.bias", f"decoder.layers.{layer_index}.linear2.bias"],
+                [f"h.{layer_index}.attn.c_proj.weight", f"decoder.layers.{layer_index}.self_attn.out_proj.weight"],
+                [f"h.{layer_index}.attn.c_proj.bias", f"decoder.layers.{layer_index}.self_attn.out_proj.bias"],
+                # attention
+                [
+                    f"h.{layer_index}.attn.c_attn.weight",
+                    f"decoder.layers.{layer_index}.self_attn.q_proj.weight",
+                    "split",
+                    0,
+                ],
+                [
+                    f"h.{layer_index}.attn.c_attn.bias",
+                    f"decoder.layers.{layer_index}.self_attn.q_proj.bias",
+                    "split",
+                    0,
+                ],
+                [
+                    f"h.{layer_index}.attn.c_attn.weight",
+                    f"decoder.layers.{layer_index}.self_attn.k_proj.weight",
+                    "split",
+                    1,
+                ],
+                [
+                    f"h.{layer_index}.attn.c_attn.bias",
+                    f"decoder.layers.{layer_index}.self_attn.k_proj.bias",
+                    "split",
+                    1,
+                ],
+                [
+                    f"h.{layer_index}.attn.c_attn.weight",
+                    f"decoder.layers.{layer_index}.self_attn.v_proj.weight",
+                    "split",
+                    2,
+                ],
+                [
+                    f"h.{layer_index}.attn.c_attn.bias",
+                    f"decoder.layers.{layer_index}.self_attn.v_proj.bias",
+                    "split",
+                    2,
+                ],
+            ]
+
+            model_mappings.extend(layer_mappings)
+
+        if "GPT2Model" not in config.architectures:
+            for mapping in model_mappings:
+                mapping[0] = "transformer." + mapping[0]
+                mapping[1] = "gpt." + mapping[1]
+
+        if "GPT2LMHeadModel" in config.architectures:
+            model_mappings.append(["lm_head.weight", "lm_head.decoder_weight"])
+
+        mappings = [StateDictNameMapping(*mapping) for mapping in model_mappings]
+        return mappings
+
     def init_weights(self, layer):
         """Initialization hook"""
         if isinstance(layer, (nn.Linear, nn.Embedding)):
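
The two config.architectures checks at the end decide how the raw mapping rows are addressed: when the checkpoint is not the bare GPT2Model, every source key gains the "transformer." prefix and every target key the "gpt." prefix, and for GPT2LMHeadModel the tied LM head weight is appended. A toy walk-through of just that branch; the two mapping rows and the hard-coded architectures value below are made up for the example.

    model_mappings = [
        ["wte.weight", "embeddings.word_embeddings.weight"],
        ["h.0.attn.c_proj.weight", "decoder.layers.0.self_attn.out_proj.weight"],
    ]
    architectures = ["GPT2LMHeadModel"]  # stand-in for config.architectures

    if "GPT2Model" not in architectures:
        for mapping in model_mappings:
            mapping[0] = "transformer." + mapping[0]  # prefix used in the torch checkpoint
            mapping[1] = "gpt." + mapping[1]          # prefix used in the paddle model

    if "GPT2LMHeadModel" in architectures:
        model_mappings.append(["lm_head.weight", "lm_head.decoder_weight"])

    assert model_mappings[0] == ["transformer.wte.weight", "gpt.embeddings.word_embeddings.weight"]
    assert model_mappings[-1] == ["lm_head.weight", "lm_head.decoder_weight"]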

paddlenlp/transformers/roberta/modeling.py
Lines changed: 5 additions & 6 deletions

@@ -13,22 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F

-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
-    QuestionAnsweringModelOutput,
-    MultipleChoiceModelOutput,
-    MaskedLMOutput,
-    CausalLMOutputWithCrossAttentions,
 )
 from .configuration import PRETRAINED_INIT_CONFIGURATION, RobertaConfig

paddlenlp/utils/serialization.py
Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from __future__ import annotations

 import io

@@ -209,7 +208,7 @@ def persistent_load_stage1(saved_id):
     result_stage1 = unpickler_stage1.load()

     # 2. get the metadata of weight file
-    metadata = []
+    metadata = {}

     def extract_maybe_dict(result):
         if isinstance(result, dict):

@@ -219,11 +218,12 @@ def extract_maybe_dict(result):
             for res in result:
                 extract_maybe_dict(res)
         elif isinstance(result, TensorMeta):
-            if result not in metadata:
-                metadata.append(result)
+            metadata[result.key] = result

     extract_maybe_dict(result_stage1)
+    metadata = list(metadata.values())
     metadata = sorted(metadata, key=lambda x: x.key)
+
     # 3. parse the tensor of pytorch weight file
     stage1_key_to_tensor = {}
     content_size = os.stat(path).st_size
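
Switching metadata from a list to a dict keyed by TensorMeta.key deduplicates records without scanning the whole list on every hit, and list(metadata.values()) plus sorted restores the list shape the rest of the code expects. A small self-contained illustration; the Meta dataclass below is a stand-in for the real TensorMeta.

    from dataclasses import dataclass

    @dataclass
    class Meta:
        key: str
        shape: tuple = ()

    records = [Meta("b", (3,)), Meta("a", (2,)), Meta("a", (2,))]

    dedup = {}
    for meta in records:
        dedup[meta.key] = meta  # one entry per key, constant-time lookup

    metadata = sorted(dedup.values(), key=lambda x: x.key)
    assert [m.key for m in metadata] == ["a", "b"]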

requirements-dev.txt
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-paddlepaddle==2.4.0rc0
+paddlepaddle>=2.4.1
 pre-commit
 pytest
 parameterized
