Commit ff8ddef

handle dtype more robustly.

1 parent 8bdc846

File tree

1 file changed: +7 −3 lines


src/diffusers/models/model_loading_utils.py

Lines changed: 7 additions & 3 deletions
@@ -190,6 +190,7 @@ def load_model_dict_into_meta(
         if param_name not in empty_state_dict:
             continue

+        set_module_kwargs = {}
         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
@@ -201,10 +202,13 @@
                 )
                 and dtype == torch.float16
             ):
-                dtype = torch.float32
-                param = param.to(dtype)
+                param = param.to(torch.float32)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = dtype

         # bnb params are flattened.
         if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
@@ -217,7 +221,7 @@
             not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
         ):
             if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
         else:
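Why the change: the old code performed the float32 upcast by reassigning the loop-wide dtype variable, so once a single float16 parameter was upcast, every parameter loaded after it in the state dict was silently cast to float32 as well. The new code leaves dtype untouched and records the per-parameter target in set_module_kwargs, which is forwarded to set_module_tensor_to_device only when that helper accepts a dtype keyword (accepts_dtype). Below is a minimal, self-contained sketch of the failure mode; the "norm"-based upcast condition and function names are illustrative stand-ins, not the actual diffusers logic:

import torch

def cast_params_buggy(params, dtype):
    # Mirrors the old behavior: reassigning the shared `dtype` variable
    # leaks the float32 override into every later iteration of the loop.
    out = {}
    for name, param in params.items():
        if "norm" in name and dtype == torch.float16:  # illustrative upcast condition
            dtype = torch.float32
            out[name] = param.to(dtype)
        else:
            out[name] = param.to(dtype)
    return out

def cast_params_fixed(params, dtype):
    # Mirrors the fix: the override stays local to the current parameter,
    # analogous to the per-parameter set_module_kwargs in the commit.
    out = {}
    for name, param in params.items():
        if "norm" in name and dtype == torch.float16:
            out[name] = param.to(torch.float32)
        else:
            out[name] = param.to(dtype)
    return out

params = {"norm.weight": torch.ones(2), "proj.weight": torch.ones(2)}
print(cast_params_buggy(params, torch.float16)["proj.weight"].dtype)  # torch.float32 (wrongly upcast)
print(cast_params_fixed(params, torch.float16)["proj.weight"].dtype)  # torch.float16 (as requested)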
