
Commit 582af9b

non_blocking does not matter for dtype cast
1 parent 591655e commit 582af9b

File tree

1 file changed: +3 -3 lines changed


src/diffusers/models/model_loading_utils.py

Lines changed: 3 additions & 3 deletions

@@ -245,13 +245,13 @@ def load_model_dict_into_meta(
         if keep_in_fp32_modules is not None and any(
             module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
         ):
-            param = param.to(torch.float32, non_blocking=True)
+            param = param.to(torch.float32)
             set_module_kwargs["dtype"] = torch.float32
         # For quantizers have save weights using torch.float8_e4m3fn
         elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
             pass
         else:
-            param = param.to(dtype, non_blocking=True)
+            param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype

         if is_accelerate_version(">=", "1.9.0.dev0"):

@@ -271,7 +271,7 @@ def load_model_dict_into_meta(

         if old_param is not None:
             if dtype is None:
-                param = param.to(old_param.dtype, non_blocking=True)
+                param = param.to(old_param.dtype)

             if old_param.is_contiguous():
                 param = param.contiguous()
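Why the change is safe (a minimal sketch, not part of the commit): torch.Tensor.to's non_blocking flag only matters for copies across devices, typically from pinned CPU memory to a GPU, where the transfer can overlap with other work. A pure dtype cast like the ones above involves no such copy, so passing non_blocking=True has no effect and dropping it does not change behavior. The snippet below illustrates this; the tensor is a hypothetical stand-in for a checkpoint parameter.

import torch

# Hypothetical stand-in for a checkpoint parameter; not taken from the commit.
param = torch.randn(4, 4, dtype=torch.float16)

# A pure dtype cast: non_blocking is accepted but changes nothing here,
# so both calls produce the same tensor.
a = param.to(torch.float32, non_blocking=True)
b = param.to(torch.float32)
assert torch.equal(a, b) and a.dtype == b.dtype == torch.float32

# non_blocking is meaningful for cross-device copies, e.g. pinned CPU
# memory to GPU, where the host-to-device transfer can run asynchronously.
if torch.cuda.is_available():
    pinned = param.pin_memory()
    on_gpu = pinned.to("cuda", non_blocking=True)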
