@@ -190,6 +190,7 @@ def load_model_dict_into_meta(
         if param_name not in empty_state_dict:
             continue

+        set_module_kwargs = {}
         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
@@ -201,10 +202,13 @@ def load_model_dict_into_meta(
             )
             and dtype == torch.float16
         ):
-            dtype = torch.float32
-            param = param.to(dtype)
+            param = param.to(torch.float32)
+            if accepts_dtype:
+                set_module_kwargs["dtype"] = torch.float32
         else:
             param = param.to(dtype)
+            if accepts_dtype:
+                set_module_kwargs["dtype"] = dtype

         # bnb params are flattened.
         if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
@@ -217,7 +221,7 @@ def load_model_dict_into_meta(
             not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
         ):
             if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
         else:
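Note on the change (not part of the patch): the removed `dtype = torch.float32` reassigned the loop-wide `dtype` variable, so once one fp16-incompatible parameter was encountered, every later parameter was also cast to float32 even when float16 was requested. The hunks above leave `dtype` untouched and instead record the per-parameter target dtype in `set_module_kwargs`, which is forwarded to `set_module_tensor_to_device`. Below is a minimal, self-contained sketch of the two patterns; the parameter names and the `startswith("norm")` incompatibility check are hypothetical stand-ins, and the `accepts_dtype` guard is omitted for brevity:

    import torch

    state_dict = {
        "norm.weight": torch.randn(4),       # pretend this one must stay float32
        "linear.weight": torch.randn(4, 4),  # this one could stay float16
    }

    # Buggy pattern (before the patch): reassigning the shared variable leaks
    # float32 into every iteration that follows.
    dtype = torch.float16
    for name, param in state_dict.items():
        if name.startswith("norm"):  # stand-in for the fp16-incompatibility check
            dtype = torch.float32    # mutates `dtype` for all later parameters
        print(name, param.to(dtype).dtype)
    # -> linear.weight loads as float32 even though float16 was requested

    # Fixed pattern (what the patch does): cast only the current parameter and
    # carry the per-parameter dtype in kwargs, leaving `dtype` untouched.
    dtype = torch.float16
    for name, param in state_dict.items():
        set_module_kwargs = {}
        if name.startswith("norm"):
            param = param.to(torch.float32)
            set_module_kwargs["dtype"] = torch.float32
        else:
            param = param.to(dtype)
            set_module_kwargs["dtype"] = dtype
        print(name, param.dtype, set_module_kwargs)
    # -> linear.weight stays float16; only norm.weight is promoted to float32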