|
16 | 16 |
|
17 | 17 | import importlib |
18 | 18 | import inspect |
19 | | -import math |
20 | 19 | import os |
21 | 20 | from array import array |
22 | 21 | from collections import OrderedDict, defaultdict |
@@ -559,27 +558,33 @@ def _expand_device_map(device_map, param_names): |
559 | 558 |
|
560 | 559 |
|
561 | 560 | # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859 |
562 | | -def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None: |
| 561 | +def _caching_allocator_warmup( |
| 562 | + model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer] |
| 563 | +) -> None: |
563 | 564 | """ |
564 | 565 | This function warms up the caching allocator based on the size of the model tensors that will reside on each |
565 | 566 | device. It allows making one large call to malloc instead of many small, recursive calls later when loading the |
566 | 567 | model, which is the actual loading-speed bottleneck. Calling this function cuts model loading time by a very |
567 | 568 | large margin. |
568 | 569 | """ |
| 570 | + factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor() |
569 | 571 | # Remove disk and cpu devices, and cast to proper torch.device |
570 | 572 | accelerator_device_map = { |
571 | 573 | param: torch.device(device) |
572 | 574 | for param, device in expanded_device_map.items() |
573 | 575 | if str(device) not in ["cpu", "disk"] |
574 | 576 | } |
575 | | - parameter_count = defaultdict(lambda: 0) |
| 577 | + total_byte_count = defaultdict(lambda: 0) |
576 | 578 | for param_name, device in accelerator_device_map.items(): |
577 | 579 | try: |
578 | 580 | param = model.get_parameter(param_name) |
579 | 581 | except AttributeError: |
580 | 582 | param = model.get_buffer(param_name) |
581 | | - parameter_count[device] += math.prod(param.shape) |
| 584 | +    # The dtype of individual parameters may differ for composite models or with `keep_in_fp32_modules` |
| 584 | + param_byte_count = param.numel() * param.element_size() |
| 585 | + # TODO: account for TP when needed. |
| 586 | + total_byte_count[device] += param_byte_count |
582 | 587 |
|
583 | 588 | # This will kick off the caching allocator to avoid having to Malloc afterwards |
584 | | - for device, param_count in parameter_count.items(): |
585 | | - _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) |
| 589 | + for device, byte_count in total_byte_count.items(): |
| 590 | + _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False) |
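
For reference, the per-device math in this hunk reduces to the standalone sketch below. It is an illustrative re-implementation, not the diffusers function itself: `warmup_allocator_sketch` and its `factor` default are assumptions. The hardcoded `factor = 2` in the diff appears to correspond to the 2-byte element size of an fp16/bf16 load dtype, so `byte_count // factor` converts the per-device byte total into an element count for `torch.empty`; quantizers can report a different factor via `get_cuda_warm_up_factor()`.

```python
from collections import defaultdict
from typing import Dict, Optional

import torch
import torch.nn as nn


def warmup_allocator_sketch(
    model: nn.Module,
    expanded_device_map: Dict[str, str],
    dtype: torch.dtype,
    factor: Optional[int] = None,
) -> None:
    """Minimal sketch of the warmup logic above; the name and `factor` default are illustrative."""
    # With no quantizer the diff hardcodes factor = 2, which matches the element size of fp16/bf16.
    # Here we derive it from the load dtype instead (an assumption about the intent of that constant).
    if factor is None:
        factor = torch.empty((), dtype=dtype).element_size()

    total_byte_count = defaultdict(int)
    for param_name, device in expanded_device_map.items():
        # Skip weights that will not live on an accelerator.
        if str(device) in ("cpu", "disk"):
            continue
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            param = model.get_buffer(param_name)
        # Count bytes, not elements, so parameters kept in fp32 (or other dtypes) are sized correctly.
        total_byte_count[torch.device(device)] += param.numel() * param.element_size()

    # One large allocation per device primes the caching allocator, replacing many small
    # cudaMalloc calls during the subsequent weight loading.
    for device, byte_count in total_byte_count.items():
        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
```

Sizing by bytes rather than by `math.prod(param.shape)` is what lets mixed-dtype checkpoints (composite models, `keep_in_fp32_modules`) reserve the right amount of memory, which is presumably why `import math` can be dropped at the top of the diff.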