Skip to content

Commit 797efa6

Browse files
authored
[Fix Download] Update weight-conversion logic & fix HF Hub download subfolder bug (#7911)
* Update weight-conversion logic (reuse an already-converted Paddle weight file) & fix HF Hub download subfolder bug
1 parent fff730e commit 797efa6

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 26 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@
5252
from paddle.utils.download import is_url as is_remote_url
5353
from tqdm.auto import tqdm
5454

55-
from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists
55+
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
5656
from paddlenlp.utils.env import (
5757
CONFIG_NAME,
5858
LEGACY_CONFIG_NAME,
@@ -367,28 +367,7 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver
367367
support_conversion (bool): whether support converting pytorch weight file to paddle weight file
368368
subfolder (str, optional) An optional value corresponding to a folder inside the repo.
369369
"""
370-
is_local = os.path.isdir(repo_id)
371-
if not is_local:
372-
if hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder):
373-
file_name = PADDLE_WEIGHTS_NAME
374-
assert (
375-
support_conversion is False
376-
), "Please call set convert_from_torch for paddle weights on huggingface hub, eg. Model.from_pretrained(model_name, from_hf_hub=True, convert_from_torch=False)"
377-
elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder):
378-
if not support_conversion:
379-
raise EntryNotFoundError(
380-
f"can not download `{PADDLE_WEIGHTS_NAME} from https://huggingface.co/{repo_id}` "
381-
"and current model doesn't support conversion from pytorch weight file to paddle weight file"
382-
)
383-
file_name = PYTORCH_WEIGHTS_NAME
384-
else:
385-
raise EntryNotFoundError(
386-
message=f"can not find the paddle/pytorch weight file from: https://huggingface.co/{repo_id}",
387-
response=None,
388-
)
389-
else:
390-
# for local file, we use support_conversion to select paddle or torch weight.
391-
file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME
370+
file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME
392371

393372
file_name_list = [SAFE_WEIGHTS_NAME] + [file_name] + [PYTORCH_WEIGHTS_INDEX_NAME] + [SAFE_WEIGHTS_INDEX_NAME]
394373
resolved_file = None
@@ -2156,12 +2135,31 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
21562135
or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
21572136
or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
21582137
):
2159-
# try to get the name-mapping info
2160-
logger.info(
2161-
f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
2162-
f"paddle weight file<{os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)}> ..."
2138+
converted_paddle_weights = os.path.join(
2139+
os.path.dirname(resolved_archive_file), PADDLE_WEIGHTS_NAME
21632140
)
2164-
state_dict = cls.convert(resolved_archive_file, config, cache_dir)
2141+
if not os.path.exists(converted_paddle_weights):
2142+
# try to get the name-mapping info
2143+
logger.info(
2144+
f"Starting to convert pytorch weight file <{resolved_archive_file}> to "
2145+
f"paddle weight file <{converted_paddle_weights}> ..."
2146+
)
2147+
state_dict = cls.convert(resolved_archive_file, config, os.path.dirname(resolved_archive_file))
2148+
else:
2149+
# try to load the converted paddle weight file
2150+
resolved_archive_file = converted_paddle_weights
2151+
sharded_metadata = None
2152+
is_sharded = False
2153+
logger.info(
2154+
f"Detect the converted Paddle weight file <{converted_paddle_weights}>. We intend to reuse this file."
2155+
)
2156+
if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
2157+
"model_state.pdparams"
2158+
):
2159+
state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
2160+
else:
2161+
state_dict = load_state_dict(resolved_archive_file)
2162+
logger.info("Loaded weights file from disk, setting weights to model.")
21652163
else:
21662164
raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
21672165
else:

paddlenlp/transformers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -587,7 +587,7 @@ def cached_file_for_hf_hub(
587587
download_check(path_or_repo_id, full_filename, addition="from_hf_hub")
588588
resolved_file = hf_hub_download(
589589
repo_id=path_or_repo_id,
590-
filename=full_filename,
590+
filename=filename,
591591
cache_dir=cache_dir,
592592
subfolder=subfolder,
593593
library_name="PaddleNLP",

0 commit comments

Comments
 (0)