# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union

import torch

from ..utils import (
    is_accelerate_available,
    logging,
)


if is_accelerate_available():
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
    Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired by the
    PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and are not registered as parameters, e.g. the running mean
            and std in batch norm layers. Please see:
            https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
        mem = mem + mem_bufs
    return mem
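

# Illustrative usage of `get_memory_footprint` (a sketch, not executed here): the returned value is just the sum of
# parameter and buffer bytes, so for a small module it can be checked by hand.
#
#     import torch
#
#     linear = torch.nn.Linear(1024, 1024)  # fp32: (1024 * 1024 + 1024) parameters * 4 bytes each
#     get_memory_footprint(linear)          # -> 4198400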


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model to the CPU until its forward pass is called. It ensures the model and its inputs
    are on the given device. Optionally, it offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available,
            then to GPU 0 if there is a GPU, and finally to the CPU.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                # candidates to offload: all other hooked models currently on the execution device
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
                import time

                # YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later)
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)

class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of
    the hook or remove it entirely.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def attach(self):
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        self.hook.add_other_hook(hook)


def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Optional[Union[str, int, torch.device]] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
    """
    Attach a `CustomOffloadHook` to `model` and return the corresponding `UserCustomOffloadHook`, which can be used
    to offload the model back to the CPU or to remove the hook entirely.
    """
    hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
    user_hook.attach()
    return user_hook
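

# Illustrative usage of the hook API (a sketch, not executed here); `unet` and `vae` stand in for arbitrary
# `torch.nn.Module`s you want to coordinate on one device:
#
#     strategy = AutoOffloadStrategy()
#     unet_hook = custom_offload_with_hook("unet", unet, "cuda:0", offload_strategy=strategy)
#     vae_hook = custom_offload_with_hook("vae", vae, "cuda:0", offload_strategy=strategy)
#     unet_hook.add_other_hook(vae_hook)
#     vae_hook.add_other_hook(unet_hook)
#
# With the hooks linked, calling `unet(...)` moves `unet` to cuda:0 and may first offload `vae` to the CPU if the
# strategy decides the device is short on memory.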


class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based
    on the available memory on the device.
    """

    def __init__(self, size_estimation_margin=0.1):
        self.size_estimation_margin = size_estimation_margin

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = get_memory_footprint(model)
        # add a safety margin to the size estimate of the incoming model
        current_module_size *= 1 + self.size_estimation_margin

        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # rank the candidate models (the hooks passed in are already on the execution device) by size, largest first
        module_sizes = dict(
            sorted(
                {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the combination of models to offload to CPU, given a dictionary of module sizes and a minimum
            amount of memory to free. The chosen combination is the one with the smallest total size that is still
            larger than `min_memory_offload`.
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    if candidate_size < min_memory_offload:
                        continue
                    if best_candidate is None or candidate_size < best_size:
                        best_candidate = candidate_model_ids
                        best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload
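

# A small worked example of the strategy above (illustrative numbers only): suppose the models already on the device
# have footprints {"text_encoder": 4 GB, "unet": 2 GB, "vae": 1 GB} and the incoming model needs 2.5 GB more than is
# currently free. The brute-force search over combinations picks {"unet", "vae"} (3 GB), the smallest total that
# still covers the 2.5 GB shortfall, rather than offloading the 4 GB "text_encoder" or everything at once.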


class ModelManager:
    """
    Keeps track of models by id and, when auto CPU offload is enabled, attaches offload hooks so that each model is
    moved to the execution device only for its forward pass and other models are offloaded back to the CPU as needed.
    """

    def __init__(self):
        self.models = OrderedDict()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def add(self, model_id, model):
        if model_id not in self.models:
            self.models[model_id] = model
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def remove(self, model_id):
        self.models.pop(model_id)
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def enable_auto_cpu_offload(self, device, size_estimation_margin=0.1):
        # remove any existing accelerate hooks before (re-)attaching the offload hooks
        for model in self.models.values():
            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
                remove_hook_from_module(model, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(size_estimation_margin=size_estimation_margin)
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")
        all_hooks = []
        for model_id, model in self.models.items():
            hook = custom_offload_with_hook(model_id, model, device, offload_strategy=offload_strategy)
            all_hooks.append(hook)

        # make every hook aware of the other hooks that share its execution device
        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def __repr__(self):
        col_widths = {
            "id": max(15, max((len(model_id) for model_id in self.models.keys()), default=0)),
            "class": max(25, max((len(model.__class__.__name__) for model in self.models.values()), default=0)),
            "device": 10,
            "dtype": 15,
            "size": 10,
        }

        # Create the header
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "ModelManager:\n" + sep_line

        # Column headers
        output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
        output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
        output += dash_line

        # Model entries
        for model_id, model in self.models.items():
            device = model.device
            dtype = model.dtype
            size_bytes = get_memory_footprint(model)
            size_gb = size_bytes / (1024**3)

            output += f"{model_id:<{col_widths['id']}} | {model.__class__.__name__:<{col_widths['class']}} | "
            output += f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n"

        output += sep_line
        return output
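

# Illustrative end-to-end usage (a sketch, not executed here); `text_encoder`, `unet`, and `vae` stand in for
# arbitrary `torch.nn.Module`s:
#
#     manager = ModelManager()
#     manager.add("text_encoder", text_encoder)
#     manager.add("unet", unet)
#     manager.add("vae", vae)
#     manager.enable_auto_cpu_offload("cuda")
#     print(manager)  # table of model ids, classes, devices, dtypes, and sizes
#
# After `enable_auto_cpu_offload`, running any of the models moves it to cuda:0 on demand and offloads the others to
# the CPU when the device runs low on memory; `manager.disable_auto_cpu_offload()` removes the hooks again.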