Commit b776aaa

pin accelerate version

1 parent: 20b1155
File tree

1 file changed, 6 insertions(+), 9 deletions(-)

src/diffusers/models/model_loading_utils.py

Lines changed: 6 additions & 9 deletions
@@ -39,6 +39,7 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -253,6 +254,10 @@ def load_model_dict_into_meta(
             param = param.to(dtype, non_blocking=True)
             set_module_kwargs["dtype"] = dtype

+        if is_accelerate_version(">=", "1.9.0.dev0"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["_empty_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -300,15 +305,7 @@ def load_model_dict_into_meta(
                 model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
             )
         else:
-            set_module_tensor_to_device(
-                model,
-                param_name,
-                param_device,
-                value=param,
-                non_blocking=True,
-                _empty_cache=False,
-                **set_module_kwargs,
-            )
+            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

     return offload_index, state_dict_index
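Below is a minimal, self-contained sketch of the pattern the diff adopts: collect optional keyword arguments in set_module_kwargs and add non_blocking / _empty_cache only when the installed accelerate is new enough to accept them. The assign_param helper is hypothetical and exists only to illustrate the idea; is_accelerate_version is exported by diffusers.utils and set_module_tensor_to_device by accelerate.utils.

# Hypothetical helper sketching the version-gated kwargs pattern from the diff above.
from accelerate.utils import set_module_tensor_to_device
from diffusers.utils import is_accelerate_version


def assign_param(model, param_name, param, device, dtype=None):
    """Place one parameter onto `device`, forwarding extra kwargs only when supported."""
    set_module_kwargs = {}
    if dtype is not None:
        param = param.to(dtype, non_blocking=True)
        set_module_kwargs["dtype"] = dtype

    # Assumption (taken from the diff's version check): `non_blocking` and
    # `_empty_cache` are only accepted by accelerate >= 1.9.0.dev0, so they are
    # added conditionally rather than passed unconditionally.
    if is_accelerate_version(">=", "1.9.0.dev0"):
        set_module_kwargs["non_blocking"] = True
        set_module_kwargs["_empty_cache"] = False

    set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)

Presumably, older accelerate releases would reject the unknown keyword arguments, so gating them on the version keeps the loader compatible with both old and new accelerate installations.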