@@ -18,7 +18,7 @@
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +38,7 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerator_device,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+# Taken from
+# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
+def _expand_device_map(device_map, param_names):
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
+# We don't incorporate the `tp_plan` stuff as we don't support it yet.
+def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> None:
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
+    }
+    if not len(accelerator_device_map):
+        return
+
+    total_byte_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        param = model.get_parameter_or_buffer(param_name)
+        # The dtype may differ across parameters with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        total_byte_count[device] += param_byte_count
+
+    # This will kick off the caching allocator to avoid having to malloc afterwards
+    for device, byte_count in total_byte_count.items():
+        if device.type == "cuda":
+            index = device.index if device.index is not None else torch.cuda.current_device()
+            device_memory = torch.cuda.mem_get_info(index)[0]
+            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve
+            # more than that amount might sometimes lead to unnecessary CUDA OOM: if the last parameter to be loaded on the
+            # device is large and the remaining reserved memory portion is smaller than the param size, torch will try to
+            # fully re-allocate the whole param size instead of using the remaining reserved part and allocating only the
+            # difference, which can lead to OOM. See
+            # https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
+            # Note that we use an absolute value instead of a device proportion here, as an 8GiB device could still allocate
+            # too much if using e.g. 90% of device size, while a 140GiB device would allocate too little.
+            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
+        # Allocate memory
+        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
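
For readers skimming the hunk: `_expand_device_map` turns a module-level device map into a per-parameter one, so every weight name resolves to a concrete device before the warmup runs. Below is a minimal, self-contained sketch of that behavior; the block and parameter names are invented for illustration and are not taken from a real model.

```python
# Copy of the helper added in the diff above, exercised on made-up names.
def _expand_device_map(device_map, param_names):
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


device_map = {"down_blocks": 0, "up_blocks": "cpu"}  # module-level assignment (hypothetical)
param_names = [
    "down_blocks.0.weight",
    "down_blocks.0.bias",
    "up_blocks.0.weight",
]
print(_expand_device_map(device_map, param_names))
# -> {'down_blocks.0.weight': 0, 'down_blocks.0.bias': 0, 'up_blocks.0.weight': 'cpu'}
```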
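The warmup only pays off at the call site, which this hunk does not show. As a hedged sketch of the arithmetic `_caching_allocator_warmup` performs, the snippet below sums the bytes destined for each accelerator and makes one throwaway allocation so the CUDA caching allocator grows its pool once instead of on every weight copy. The toy model and the hard-coded `cuda:0` assignment are assumptions for illustration, not the PR's actual wiring.

```python
from collections import defaultdict

import torch
import torch.nn as nn

# Toy stand-in for a diffusers model (illustrative only).
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# Pretend the expanded device map put every parameter on cuda:0.
expanded_device_map = {name: torch.device("cuda", 0) for name, _ in model.named_parameters()}

total_byte_count = defaultdict(int)
for name, device in expanded_device_map.items():
    param = model.get_parameter(name)
    total_byte_count[device] += param.numel() * param.element_size()

if torch.cuda.is_available():
    for device, byte_count in total_byte_count.items():
        free_bytes = torch.cuda.mem_get_info(device.index)[0]
        # Same cap as in the PR: never try to reserve more than (free memory - 1.2 GiB).
        byte_count = min(byte_count, max(0, int(free_bytes - 1.2 * 1024**3)))
        # With factor=2, a float16 tensor of byte_count // 2 elements occupies roughly byte_count bytes.
        _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
```

The dummy tensor matters only for its side effect: it forces the cudaMalloc up front, and once the tensor is released the memory stays in the caching allocator's reserved pool, where the subsequent weight copies can reuse it without further raw allocations.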