Skip to content

Commit 8879f79

Browse files
authored
quick fix from pretrained. (#8487)
1 parent 7a24bcc commit 8879f79

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def _load_state_dict_into_meta_model(
798798

799799
dtype = convert_np_dtype_to_dtype_(dtype)
800800
error_msgs = []
801-
801+
model_state_dict = model.state_dict()
802802
for param_name, param in state_dict.items():
803803
# First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
804804
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
833833
if old_param is not None:
834834
param = param.astype(dtype=old_param.dtype)
835835
with paddle.no_grad():
836-
model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())
836+
model_state_dict[param_name].get_tensor()._share_data_with(param.value().get_tensor())
837837
param.value().get_tensor()._clear()
838838
return error_msgs
839839

@@ -1890,7 +1890,7 @@ def _find_mismatched_keys(
18901890
if (
18911891
shard_file.endswith(".safetensors")
18921892
and config.tensor_parallel_degree > 1
1893-
and "tp" not in shard_file
1893+
and "tp" not in os.path.spilt(shard_file)[-1]
18941894
):
18951895
pre_tensor_parallel_split = True
18961896
assert loaded_keys is not None, "loaded_keys is not None."

0 commit comments

Comments
 (0)