# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from collections import OrderedDict
from typing import List, Optional, Union

import safetensors
import torch

from ..utils import (
    SAFETENSORS_FILE_EXTENSION,
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)


if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map

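# Illustrative usage sketch for `_determine_device_map` (comment-only; the model
# class, memory budget, and dtype below are assumptions for the example, not part
# of this module):
#
#     from accelerate import init_empty_weights
#
#     with init_empty_weights():
#         model = UNet2DConditionModel(**config)  # hypothetical config dict
#     device_map = _determine_device_map(
#         model, device_map="balanced", max_memory={0: "10GiB", "cpu": "24GiB"}, torch_dtype=torch.float16
#     )
#     # `device_map` now maps module names to devices, e.g. {"conv_in": 0, ..., "conv_out": "cpu"}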

def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file '{checkpoint_file}'."
            )
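
# Illustrative usage sketch for `load_state_dict` (comment-only; the checkpoint
# path below is an assumption for the example):
#
#     state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")
#     print(f"Loaded {len(state_dict)} tensors on CPU")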


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        if empty_state_dict[param_name].shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape "
                f"{empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite "
                "randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and "
                "`ignore_mismatched_sizes=True`. For more information, see also: "
                "https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys
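
# Illustrative usage sketch for `load_model_dict_into_meta` (comment-only). It
# assumes the model was materialized on the meta device, e.g. via accelerate's
# `init_empty_weights`; the model class below is hypothetical:
#
#     from accelerate import init_empty_weights
#
#     with init_empty_weights():
#         model = MyDiffusionModel(config)  # hypothetical model class
#     unexpected = load_model_dict_into_meta(model, state_dict, device="cuda:0", dtype=torch.float16)
#     if unexpected:
#         logger.warning(f"Unexpected keys in state_dict: {unexpected}")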


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # Copy state_dict so `_load_from_state_dict` can modify it.
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        # args: (state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs
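
# Illustrative usage sketch for `_load_state_dict_into_model` (comment-only; the
# model and state_dict below are assumptions for the example):
#
#     error_msgs = _load_state_dict_into_model(model, state_dict)
#     if error_msgs:
#         raise RuntimeError(f"Error(s) in loading state_dict: {error_msgs}")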