|
@@ -16,9 +16,10 @@
 
 import importlib
 import inspect
+import math
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +39,7 @@
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -252,6 +254,10 @@ def load_model_dict_into_meta( |
                 param = param.to(dtype)
                 set_module_kwargs["dtype"] = dtype
 
+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
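
The new kwargs are only collected here; outside this hunk they are forwarded to accelerate's `set_module_tensor_to_device` when parameters are materialized from the meta device. A minimal sketch of that forwarding, assuming such a call site and assuming that `non_blocking`/`clear_cache` are only accepted by accelerate > 1.8.1 (which is why the diff gates them behind the version check):

```python
# Illustrative only -- not part of the diff. Assumes accelerate > 1.8.1 accepts
# `non_blocking` and `clear_cache` on set_module_tensor_to_device.
from accelerate.utils import set_module_tensor_to_device


def _materialize_param(model, param_name, device, param, set_module_kwargs):
    # `set_module_kwargs` may carry `dtype`, plus `non_blocking=True` (overlap the
    # host-to-device copy) and `clear_cache=False` (skip the per-parameter device
    # cache flush, which is expensive when repeated for thousands of tensors).
    set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
```
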
@@ -520,3 +526,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): |
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
 
     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the mapping from parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+    """
+    Warm up the caching allocator based on the size of the model tensors that will reside on each device. This makes
+    one large allocation per device up front, instead of many smaller allocations later while loading the model,
+    which is the actual loading-speed bottleneck. Calling this function can therefore cut model loading time by a
+    very large margin.
+    """
+    # Remove disk and cpu devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    parameter_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            param = model.get_parameter(param_name)
+        except AttributeError:
+            param = model.get_buffer(param_name)
+        parameter_count[device] += math.prod(param.shape)
+
+    # This primes the caching allocator so we avoid a separate malloc for every parameter afterwards
+    for device, param_count in parameter_count.items():
+        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
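
The helpers added above are module-private. A rough usage sketch of `_find_mismatched_keys`: the surrounding loader and the print-based reporting are hypothetical, only the helper itself comes from this commit.

```python
# Rough usage sketch -- not part of the commit. `ignore_mismatched_sizes` mirrors the
# loader flag of the same name; the reporting below is hypothetical.
def _load_with_mismatch_report(model, state_dict, ignore_mismatched_sizes=True):
    model_state_dict = model.state_dict()
    loaded_keys = list(state_dict.keys())

    # Removes mismatched entries from `state_dict` in place and returns
    # (key, checkpoint_shape, model_shape) triples for reporting.
    mismatched = _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes)

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    for key, ckpt_shape, model_shape in mismatched:
        print(f"Skipping {key}: checkpoint shape {ckpt_shape} does not match model shape {model_shape}")
    return missing, unexpected
```

`_expand_device_map` and `_caching_allocator_warmup` are designed to work together: the expanded per-parameter map tells the warmup how many elements will land on each accelerator, so a single `torch.empty` per device primes the caching allocator before the real weights are copied in. A self-contained sketch with a hypothetical model and device map; only the two helpers come from this commit.

```python
# Illustrative sketch -- not part of the commit.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(1024, 1024)
        self.decoder = nn.Linear(1024, 1024)


model = TinyModel()
device_map = {"encoder": 0, "decoder": "cpu"}  # module-level map, e.g. produced by accelerate

# Expand the module-level map to one entry per parameter/buffer name,
# e.g. {"encoder.weight": 0, "encoder.bias": 0, "decoder.weight": "cpu", ...}
param_names = [name for name, _ in model.named_parameters()] + [name for name, _ in model.named_buffers()]
expanded = _expand_device_map(device_map, param_names)

if torch.cuda.is_available():
    # One big allocation per accelerator device warms up the caching allocator, so the
    # many per-parameter copies that follow do not each trigger a fresh cudaMalloc.
    _caching_allocator_warmup(model, expanded, dtype=torch.float16)
```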