|
52 | 52 | from paddle.utils.download import is_url as is_remote_url
|
53 | 53 | from tqdm.auto import tqdm
|
54 | 54 |
|
55 |
| -from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists |
| 55 | +from paddlenlp.utils.downloader import get_path_from_url_with_filelock |
56 | 56 | from paddlenlp.utils.env import (
|
57 | 57 | CONFIG_NAME,
|
58 | 58 | LEGACY_CONFIG_NAME,
|
@@ -367,28 +367,7 @@ def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conver
|
367 | 367 | support_conversion (bool): whether to support converting a pytorch weight file to a paddle weight file
|
368 | 368 | subfolder (str, optional): An optional value corresponding to a folder inside the repo.
|
369 | 369 | """
|
370 |
| - is_local = os.path.isdir(repo_id) |
371 |
| - if not is_local: |
372 |
| - if hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder): |
373 |
| - file_name = PADDLE_WEIGHTS_NAME |
374 |
| - assert ( |
375 |
| - support_conversion is False |
376 |
| - ), "Please call set convert_from_torch for paddle weights on huggingface hub, eg. Model.from_pretrained(model_name, from_hf_hub=True, convert_from_torch=False)" |
377 |
| - elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder): |
378 |
| - if not support_conversion: |
379 |
| - raise EntryNotFoundError( |
380 |
| - f"can not download `{PADDLE_WEIGHTS_NAME} from https://huggingface.co/{repo_id}` " |
381 |
| - "and current model doesn't support conversion from pytorch weight file to paddle weight file" |
382 |
| - ) |
383 |
| - file_name = PYTORCH_WEIGHTS_NAME |
384 |
| - else: |
385 |
| - raise EntryNotFoundError( |
386 |
| - message=f"can not find the paddle/pytorch weight file from: https://huggingface.co/{repo_id}", |
387 |
| - response=None, |
388 |
| - ) |
389 |
| - else: |
390 |
| - # for local file, we use support_conversion to select paddle or torch weight. |
391 |
| - file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME |
| 370 | + file_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME |
392 | 371 |
|
393 | 372 | file_name_list = [SAFE_WEIGHTS_NAME] + [file_name] + [PYTORCH_WEIGHTS_INDEX_NAME] + [SAFE_WEIGHTS_INDEX_NAME]
|
394 | 373 | resolved_file = None
|
@@ -2156,12 +2135,31 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
2156 | 2135 | or resolved_archive_file.endswith(SAFE_WEIGHTS_NAME)
|
2157 | 2136 | or resolved_archive_file.endswith(SAFE_WEIGHTS_INDEX_NAME)
|
2158 | 2137 | ):
|
2159 |
| - # try to get the name-mapping info |
2160 |
| - logger.info( |
2161 |
| - f"Starting to convert pytorch weight file<{resolved_archive_file}> to " |
2162 |
| - f"paddle weight file<{os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)}> ..." |
| 2138 | + converted_paddle_weights = os.path.join( |
| 2139 | + os.path.dirname(resolved_archive_file), PADDLE_WEIGHTS_NAME |
2163 | 2140 | )
|
2164 |
| - state_dict = cls.convert(resolved_archive_file, config, cache_dir) |
| 2141 | + if not os.path.exists(converted_paddle_weights): |
| 2142 | + # try to get the name-mapping info |
| 2143 | + logger.info( |
| 2144 | + f"Starting to convert pytorch weight file <{resolved_archive_file}> to " |
| 2145 | + f"paddle weight file <{converted_paddle_weights}> ..." |
| 2146 | + ) |
| 2147 | + state_dict = cls.convert(resolved_archive_file, config, os.path.dirname(resolved_archive_file)) |
| 2148 | + else: |
| 2149 | + # try to load the converted paddle weight file |
| 2150 | + resolved_archive_file = converted_paddle_weights |
| 2151 | + sharded_metadata = None |
| 2152 | + is_sharded = False |
| 2153 | + logger.info( |
| 2154 | + f"Detect the converted Paddle weight file <{converted_paddle_weights}>. We intend to reuse this file." |
| 2155 | + ) |
| 2156 | + if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith( |
| 2157 | + "model_state.pdparams" |
| 2158 | + ): |
| 2159 | + state_dict = cls.convert_tensor_parallel(resolved_archive_file, config) |
| 2160 | + else: |
| 2161 | + state_dict = load_state_dict(resolved_archive_file) |
| 2162 | + logger.info("Loaded weights file from disk, setting weights to model.") |
2165 | 2163 | else:
|
2166 | 2164 | raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
|
2167 | 2165 | else:
|
|
0 commit comments