Skip to content

Commit 14e6813

Browse files
gongenleiwj-Mcat
andauthored
fix load_fp16 model (#3902)
Co-authored-by: 骑马小猫 <[email protected]>
1 parent cef8291 commit 14e6813

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,26 @@ def _find_mismatched_keys(
10641064
raise ValueError(
10651065
f"the value of `dtype` should be one of [`float32`, `float16`], but received {dtype}"
10661066
)
1067-
for key in state_to_load.keys():
1068-
state_to_load[key] = paddle.cast(state_to_load[key],
1069-
dtype=dtype)
1067+
for key in state_dict.keys():
1068+
state_dict[key] = paddle.cast(state_dict[key], dtype=dtype)
1069+
else:
1070+
dtype_prefix_len = len("paddle.")
1071+
for k, v in model_to_load.state_dict().items():
1072+
if not isinstance(v, np.ndarray):
1073+
dtype = str(v.dtype)[dtype_prefix_len:]
1074+
if k in state_dict:
1075+
if paddle.in_dynamic_mode():
1076+
if isinstance(state_dict[k], np.ndarray):
1077+
state_dict[k] = state_dict[k].astype(dtype)
1078+
else:
1079+
state_dict[k] = paddle.cast(state_dict[k], dtype)
1080+
else:
1081+
# there are some latent error when case dtype in static-mode, so let's:
1082+
# 1. convert fluid.*.Tensor -> numpy.ndarray
1083+
# 2. cast the dtype with numpy tools
1084+
# 3. paddle works well with ndarray state-dict
1085+
state_dict[k] = np.array(state_dict[k])
1086+
state_dict[k] = state_dict[k].astype(dtype)
10701087

10711088
# For model parallel if FasterGeneration
10721089
# To avoid recursive import temporarily.

0 commit comments

Comments
 (0)