File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -798,7 +798,7 @@ def _load_state_dict_into_meta_model(
798
798
799
799
dtype = convert_np_dtype_to_dtype_ (dtype )
800
800
error_msgs = []
801
-
801
+ model_state_dict = model . state_dict ()
802
802
for param_name , param in state_dict .items ():
803
803
# First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
804
804
if param_name not in loaded_state_dict_keys or param_name not in expected_keys :
@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
833
833
if old_param is not None :
834
834
param = param .astype (dtype = old_param .dtype )
835
835
with paddle .no_grad ():
836
- model . state_dict () [param_name ].get_tensor ()._share_data_with (param .value ().get_tensor ())
836
+ model_state_dict [param_name ].get_tensor ()._share_data_with (param .value ().get_tensor ())
837
837
param .value ().get_tensor ()._clear ()
838
838
return error_msgs
839
839
@@ -1890,7 +1890,7 @@ def _find_mismatched_keys(
1890
1890
if (
1891
1891
shard_file .endswith (".safetensors" )
1892
1892
and config .tensor_parallel_degree > 1
1893
- and "tp" not in shard_file
1893
+ and "tp" not in os . path . spilt ( shard_file )[ - 1 ]
1894
1894
):
1895
1895
pre_tensor_parallel_split = True
1896
1896
assert loaded_keys is not None , "loaded_keys is not None."
You can’t perform that action at this time.
0 commit comments