
Commit 58bf268

support hf_quantizer in cache warmup. (#12043)
* support hf_quantizer in cache warmup.
* reviewer feedback
* up
* up
1 parent 1b48db4 commit 58bf268

4 files changed, +50 -9 lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 11 additions & 6 deletions
@@ -17,7 +17,6 @@
 import functools
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names):
 
 
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
     which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
     very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
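
For intuition, here is a minimal standalone sketch of the byte accounting the new version performs. The toy module, device choice, and factor value are illustrative only; the real helper skips `cpu`/`disk` entries and takes the factor from the quantizer.

```python
from collections import defaultdict

import torch

# Toy stand-in for a real diffusers model (illustrative only).
model = torch.nn.Linear(1024, 1024, bias=False, dtype=torch.float16)

# Pretend every parameter was mapped to the same device; "cpu" is used here only
# so the sketch runs anywhere, whereas the real helper warms up accelerators only.
expanded_device_map = {name: "cpu" for name, _ in model.named_parameters()}

factor = 2  # no quantizer -> warm up the full fp16 footprint
total_byte_count = defaultdict(int)
for param_name, device in expanded_device_map.items():
    param = model.get_parameter(param_name)
    # numel() * element_size() counts bytes correctly even when parameters have mixed dtypes
    total_byte_count[torch.device(device)] += param.numel() * param.element_size()

for device, byte_count in total_byte_count.items():
    # One large allocation primes the caching allocator; with fp16 (2 bytes/element),
    # byte_count // 2 elements occupy exactly byte_count bytes.
    _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
    print(device, byte_count)  # cpu 2097152  (1024 * 1024 params * 2 bytes)
```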

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 3 deletions
@@ -1532,10 +1532,9 @@ def _load_pretrained_model(
         # tensors using their expected shape and not performing any initialization of the memory (empty data).
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
-        # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
 
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
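
As a rough sketch of what this call site operates on, the expanded device map is simply a per-parameter device assignment. The device map and parameter names below are made up, and the expansion is approximated by hand rather than calling the real `_expand_device_map`:

```python
# Hypothetical device_map and expected parameter keys (not from the commit).
device_map = {"": "cuda:0"}
expected_keys = ["conv_in.weight", "conv_in.bias", "time_embedding.linear_1.weight"]

# Hand-rolled approximation of the expansion: with a single "" entry, every
# expected parameter ends up on that device.
expanded_device_map = {key: device_map[""] for key in expected_keys}

# The warmup then sizes one allocation per accelerator device in this map,
# using hf_quantizer (when present) to shrink the pre-allocation.
print(expanded_device_map)
```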

src/diffusers/quantizers/base.py

Lines changed: 11 additions & 0 deletions
@@ -209,6 +209,17 @@ def dequantize(self, model):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
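
To make the factor semantics concrete, here is a small worked example under the assumption that the warmup dtype is fp16 (2 bytes per element), as in the default case:

```python
# An "empty" model whose weights occupy 1 GiB of bytes (illustrative number).
byte_count = 1 * 1024**3
element_size = 2  # bytes per element for the fp16 warmup dtype

for factor in (2, 4, 8):
    # torch.empty(byte_count // factor, dtype=fp16) allocates this many bytes:
    allocated = (byte_count // factor) * element_size
    print(factor, allocated / 1024**3)
# 2 -> 1.0 GiB (full footprint), 4 -> 0.5 GiB (half), 8 -> 0.25 GiB (a quarter)
```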

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 26 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,31 @@ def create_quantized_param(
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        # For the uint types, this is a best guess. Once these types become more used
+        # we can look into their nuances.
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
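
A quick illustration of the glob matching above; the `quant_type` strings here are assumed examples, and the actual `quant_type` names accepted by a given `TorchAoConfig` may be spelled differently:

```python
from fnmatch import fnmatch

map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}

def warm_up_factor(quant_type: str) -> int:
    # The first glob pattern that matches wins, mirroring the loop in get_cuda_warm_up_factor.
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")

print(warm_up_factor("int4_weight_only"))   # 8 -> pre-allocate about a quarter of the fp16 footprint
print(warm_up_factor("int8_weight_only"))   # 4 -> pre-allocate about half of the fp16 footprint
print(warm_up_factor("float8_weight_only")) # 4
```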
