
Commit 84615ea

Fix gpt download bug (#8253)
1 parent 814e9c4 commit 84615ea

File tree

2 files changed: +18 additions, −2 deletions


paddlenlp/transformers/gpt/configuration.py

Lines changed: 12 additions & 1 deletion
@@ -18,7 +18,7 @@
 
 from paddlenlp.transformers.configuration_utils import PretrainedConfig
 
-__all__ = ["GPT_PRETRAINED_INIT_CONFIGURATION", "GPTConfig"]
+__all__ = ["GPT_PRETRAINED_INIT_CONFIGURATION", "GPTConfig", "GPT_PRETRAINED_RESOURCE_FILES_MAP"]
 
 GPT_PRETRAINED_INIT_CONFIGURATION = {
     "gpt-cpm-large-cn": {  # 2.6B
@@ -147,6 +147,17 @@
     },
 }
 
+GPT_PRETRAINED_RESOURCE_FILES_MAP = {
+    "model_state": {
+        "gpt-cpm-large-cn": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-cpm-large-cn.pdparams",
+        "gpt-cpm-small-cn-distill": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-cpm-small-cn-distill.pdparams",
+        "gpt2-en": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-en.pdparams",
+        "gpt2-medium-en": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-medium-en.pdparams",
+        "gpt2-large-en": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-large-en.pdparams",
+        "gpt2-xl-en": "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-xl-en.pdparams",
+    }
+}
+
 class GPTConfig(PretrainedConfig):
     r"""

paddlenlp/transformers/gpt/modeling.py

Lines changed: 6 additions & 1 deletion
@@ -49,7 +49,11 @@
     TokenClassifierOutput,
 )
 from ..model_utils import dy2st_nocheck_guard_context
-from .configuration import GPT_PRETRAINED_INIT_CONFIGURATION, GPTConfig
+from .configuration import (
+    GPT_PRETRAINED_INIT_CONFIGURATION,
+    GPT_PRETRAINED_RESOURCE_FILES_MAP,
+    GPTConfig,
+)
 
 try:
     from paddle.nn.functional.flash_attention import flash_attention
@@ -787,6 +791,7 @@ class GPTPretrainedModel(PretrainedModel):
     base_model_prefix = "gpt"
     config_class = GPTConfig
     pretrained_init_configuration = GPT_PRETRAINED_INIT_CONFIGURATION
+    pretrained_resource_files_map = GPT_PRETRAINED_RESOURCE_FILES_MAP
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config, is_split=True):
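
With the map attached to GPTPretrainedModel as pretrained_resource_files_map, loading a built-in checkpoint by name can again download its weights from the registered URL, which is what this commit fixes. A usage sketch (assumes paddlenlp is installed and the bcebos URLs are reachable):

```python
# Usage sketch: from_pretrained resolves "gpt2-en" through
# pretrained_resource_files_map and downloads gpt2-en.pdparams on first use.
from paddlenlp.transformers import GPTModel

model = GPTModel.from_pretrained("gpt2-en")
```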
