Commit 7a9c448

update

1 parent 754fe85 commit 7a9c448
File tree

1 file changed
+41 -4 lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 41 additions & 4 deletions
@@ -16,9 +16,10 @@
 
 import importlib
 import inspect
+import math
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -230,6 +231,16 @@ def load_model_dict_into_meta(
 
     is_quantized = hf_quantizer is not None
     empty_state_dict = model.state_dict()
+    expanded_device_map = {}
+
+    if device_map is not None:
+        for param_name, param in state_dict.items():
+            if param_name not in empty_state_dict:
+                continue
+            param_device = _determine_param_device(param_name, device_map)
+            expanded_device_map[param_name] = param_device
+        print(expanded_device_map)
+        _caching_allocator_warmup(model, expanded_device_map, dtype)
 
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
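
The added block resolves every parameter to a target device up front: _determine_param_device (defined elsewhere in this file, not shown in this diff) maps a fully qualified parameter name to an entry of the user-supplied device_map. As a rough sketch of that kind of lookup, assuming device_map keys are module-path prefixes and "" acts as a catch-all (the function name and exact matching rules below are illustrative assumptions, not the file's actual implementation):

def resolve_param_device(param_name, device_map):
    # Illustrative stand-in for _determine_param_device: find the most
    # specific device_map entry that covers this parameter name.
    candidates = [k for k in device_map if k == "" or param_name == k or param_name.startswith(k + ".")]
    if not candidates:
        raise ValueError(f"{param_name} is not covered by device_map.")
    # The longest matching prefix is the most specific one.
    return device_map[max(candidates, key=len)]

device_map = {"": "cpu", "transformer_blocks": 0}
assert resolve_param_device("transformer_blocks.0.attn.to_q.weight", device_map) == 0
assert resolve_param_device("proj_out.weight", device_map) == "cpu"
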
@@ -243,13 +254,13 @@ def load_model_dict_into_meta(
         if keep_in_fp32_modules is not None and any(
             module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
         ):
-            param = param.to(torch.float32)
+            param = param.to(torch.float32, non_blocking=True)
             set_module_kwargs["dtype"] = torch.float32
         # For quantizers that have saved weights using torch.float8_e4m3fn
         elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
             pass
         else:
-            param = param.to(dtype)
+            param = param.to(dtype, non_blocking=True)
             set_module_kwargs["dtype"] = dtype
 
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
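
The non_blocking=True additions in this hunk let each .to() call return before the underlying copy finishes, so the loading loop can keep queueing transfers instead of stalling on every tensor. A host-to-device copy is only truly asynchronous when the source lives in pinned (page-locked) host memory, and the stream must be synchronized before the result is consumed. A minimal standalone illustration (not part of this diff):

import torch

if torch.cuda.is_available():
    src = torch.randn(1024, 1024).pin_memory()  # pinned memory enables async H2D copies
    dst = src.to("cuda", non_blocking=True)     # enqueues the copy and returns immediately
    torch.cuda.synchronize()                    # wait before reading dst's contents
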
@@ -265,7 +276,7 @@ def load_model_dict_into_meta(
 
         if old_param is not None:
             if dtype is None:
-                param = param.to(old_param.dtype)
+                param = param.to(old_param.dtype, non_blocking=True)
 
             if old_param.is_contiguous():
                 param = param.contiguous()
@@ -520,3 +531,29 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
     return parsed_parameters
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+    """Warms up the caching allocator based on the size of the model tensors that will reside on each
+    device. This makes one large call to malloc up front, instead of many repeated calls while loading
+    the model, which is the actual loading-speed bottleneck. Calling this function can cut model loading
+    time by a very large margin.
+    """
+    # Remove disk and cpu devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    parameter_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            param = model.get_parameter(param_name)
+        except AttributeError:
+            param = model.get_buffer(param_name)
+        parameter_count[device] += math.prod(param.shape)
+
+    # This kicks off the caching allocator so the real loads afterwards avoid fresh malloc calls
+    for device, param_count in parameter_count.items():
+        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
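
The warmup pays off because PyTorch's CUDA caching allocator keeps freed blocks instead of returning them to the driver: one large torch.empty per device triggers a single expensive cudaMalloc, and the many small per-tensor allocations made while loading are then carved out of the cached block. A rough standalone demonstration of the effect (sizes are arbitrary and a CUDA device is assumed):

import torch

if torch.cuda.is_available():
    warmup = torch.empty(256 * 1024 * 1024, dtype=torch.float16, device="cuda")  # one big cudaMalloc
    del warmup  # freed memory stays in PyTorch's cache rather than going back to the driver
    print(torch.cuda.memory_reserved())  # reserved bytes stay high after the free
    w = torch.empty(4096, 4096, dtype=torch.float16, device="cuda")  # served from the cached block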
