
Commit 51d1436

faster model loading on cuda.
1 parent 0dec414 commit 51d1436

4 files changed, +85 -1 lines changed


src/diffusers/models/model_loading_utils.py

Lines changed: 47 additions & 1 deletion
@@ -18,7 +18,7 @@
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +38,7 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerator_device,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+# Taken from
+# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
+def _expand_device_map(device_map, param_names):
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
+# We don't incorporate the `tp_plan` stuff as we don't support it yet.
+def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> None:
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
+    }
+    if not len(accelerator_device_map):
+        return
+
+    total_byte_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        param = model.get_parameter_or_buffer(param_name)
+        # The dtype of different parameters may differ with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        total_byte_count[device] += param_byte_count
+
+    # This kicks off the caching allocator, so we avoid having to malloc afterwards
+    for device, byte_count in total_byte_count.items():
+        if device.type == "cuda":
+            index = device.index if device.index is not None else torch.cuda.current_device()
+            device_memory = torch.cuda.mem_get_info(index)[0]
+            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
+            # than that amount might sometimes lead to an unnecessary CUDA OOM: if the last parameter to be loaded on the device is large,
+            # and the remaining reserved memory portion is smaller than the param size, torch will then try to fully re-allocate
+            # the whole param size, instead of using the remaining reserved part and allocating only the difference, which can lead
+            # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
+            # Note that we use an absolute value instead of a device proportion here, as an 8 GiB device could still allocate too much
+            # if using e.g. 90% of device size, while a 140 GiB device would allocate too little.
+            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
+            # Allocate memory
+            _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
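A side note on the two helpers above: `_expand_device_map` turns a module-level device map into a per-parameter one, and `_caching_allocator_warmup` then makes a single large allocation per accelerator so PyTorch's caching allocator grows its pool once, rather than tensor by tensor while the state dict is copied in. A small standalone sketch of the expansion step, using made-up parameter names (nothing below is part of the commit):

# A module-level device map: "" matches every parameter, more specific entries override it.
device_map = {"": "cuda:0", "text_encoder": "cpu"}
param_names = ["conv_in.weight", "text_encoder.embed.weight"]

expanded = {}
for module, device in device_map.items():
    expanded.update(
        {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
    )

# expanded == {"conv_in.weight": "cuda:0", "text_encoder.embed.weight": "cpu"}
#
# The warmup then sums param.numel() * param.element_size() per accelerator device; with
# factor=2, torch.empty(byte_count // 2, dtype=torch.float16) reserves byte_count // 2
# half-precision elements, i.e. roughly byte_count bytes, capped at (free memory - 1.2 GiB).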

src/diffusers/models/modeling_utils.py

Lines changed: 25 additions & 0 deletions
@@ -63,7 +63,9 @@
     populate_model_card,
 )
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
     _load_state_dict_into_model,
@@ -1374,6 +1376,24 @@ def float(self, *args):
         else:
             return super().float(*args)

+    # Taken from `transformers`.
+    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
+    def get_parameter_or_buffer(self, target: str):
+        """
+        Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
+        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only works if `target` is a leaf
+        of the model.
+        """
+        try:
+            return self.get_parameter(target)
+        except AttributeError:
+            pass
+        try:
+            return self.get_buffer(target)
+        except AttributeError:
+            pass
+        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
+
     @classmethod
     def _load_pretrained_model(
         cls,
@@ -1410,6 +1430,11 @@ def _load_pretrained_model(
         assign_to_params_buffers = None
         error_msgs = []

+        # Optionally warm up the CUDA caching allocator so the weights load onto the device much faster
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
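For context, this warmup is exercised whenever a model is loaded with a `device_map`. A minimal sketch of the user-facing call, assuming a diffusers build whose `from_pretrained` accepts a single-device `device_map` (the checkpoint name is only illustrative):

import torch
from diffusers import UNet2DConditionModel

# With a device_map, _load_pretrained_model expands it to per-parameter devices and runs
# _caching_allocator_warmup before the weights are copied, so loading onto CUDA is faster.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # illustrative checkpoint, not part of this commit
    subfolder="unet",
    torch_dtype=torch.float16,
    device_map="cuda",
)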

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@
     convert_unet_state_dict_to_peft,
     state_dict_all_zero,
 )
+from .testing_utils import is_accelerator_device
 from .typing_utils import _get_detailed_type, _is_valid_type

src/diffusers/utils/testing_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,18 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
12891289
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
12901290

12911291

1292+
if is_torch_available():
1293+
# Taken from
1294+
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
1295+
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
1296+
"""Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
1297+
a proper `torch.device`.
1298+
"""
1299+
if device == "disk":
1300+
return False
1301+
else:
1302+
return torch.device(device).type not in ["meta", "cpu"]
1303+
12921304
# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
12931305

12941306
# Type definition of key used in `Expectations` class.
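A quick illustration of the helper's behavior; the values below only assume a torch build that knows the "cuda" device type, since constructing a `torch.device` does not require the device to actually be present:

from diffusers.utils import is_accelerator_device

is_accelerator_device("disk")    # False: not a torch.device at all
is_accelerator_device("cpu")     # False
is_accelerator_device("meta")    # False
is_accelerator_device("cuda:1")  # True, whether or not that index exists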
