@@ -1064,9 +1064,26 @@ def _find_mismatched_keys(
1064
1064
raise ValueError (
1065
1065
f"the value of `dtype` should be one of [`float32`, `float16`], but received { dtype } "
1066
1066
)
1067
- for key in state_to_load .keys ():
1068
- state_to_load [key ] = paddle .cast (state_to_load [key ],
1069
- dtype = dtype )
1067
+ for key in state_dict .keys ():
1068
+ state_dict [key ] = paddle .cast (state_dict [key ], dtype = dtype )
1069
+ else :
1070
+ dtype_prefix_len = len ("paddle." )
1071
+ for k , v in model_to_load .state_dict ().items ():
1072
+ if not isinstance (v , np .ndarray ):
1073
+ dtype = str (v .dtype )[dtype_prefix_len :]
1074
+ if k in state_dict :
1075
+ if paddle .in_dynamic_mode ():
1076
+ if isinstance (state_dict [k ], np .ndarray ):
1077
+ state_dict [k ] = state_dict [k ].astype (dtype )
1078
+ else :
1079
+ state_dict [k ] = paddle .cast (state_dict [k ], dtype )
1080
+ else :
1081
+ # there are some latent error when case dtype in static-mode, so let's:
1082
+ # 1. convert fluid.*.Tensor -> numpy.ndarray
1083
+ # 2. cast the dtype with numpy tools
1084
+ # 3. paddle works well with ndarray state-dict
1085
+ state_dict [k ] = np .array (state_dict [k ])
1086
+ state_dict [k ] = state_dict [k ].astype (dtype )
1070
1087
1071
1088
# For model parallel if FasterGeneration
1072
1089
# To avoid recursive import temporarily.
0 commit comments