# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union

import torch

from ..utils import (
    is_accelerate_available,
    logging,
)


if is_accelerate_available():
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache
    from accelerate.utils.modeling import convert_file_size_to_int

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful
    for benchmarking the memory footprint of the current model and designing some tests. Solution inspired by this
    PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to include the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and are not registered as parameters, e.g. the running mean
            and std in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
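
    Example (illustrative; a small `torch.nn.Linear` stands in for any PyTorch module):

    ```py
    >>> import torch
    >>> linear = torch.nn.Linear(1024, 1024)
    >>> get_memory_footprint(linear)  # (1024 * 1024 + 1024) fp32 elements, 4 bytes each
    4198400
    ```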
    """
    mem = sum(param.nelement() * param.element_size() for param in self.parameters())
    if return_buffers:
        mem_bufs = sum(buf.nelement() * buf.element_size() for buf in self.buffers())
        mem = mem + mem_bufs
    return mem


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model to the CPU until its forward pass is called. It ensures that the model and its
    inputs are on the given device. Optionally, it offloads other models to the CPU before the forward pass is
    called.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
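
    Example (illustrative sketch; `model` is assumed to be any `torch.nn.Module`):

    ```py
    >>> hook = CustomOffloadHook(execution_device="cuda:0")
    >>> add_hook_to_module(model, hook)  # model moves to the CPU now, and to cuda:0 when its forward runs
    ```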
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                # offload the other hooks whose models currently occupy the execution device
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]

                # YiYi Notes: only logging time for now to monitor the overhead of the offload strategy (remove later)
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f"time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)


class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of
    the hook or remove it entirely.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def attach(self):
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        self.hook.add_other_hook(hook)


def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: Optional[Union[str, int, torch.device]] = None,
    offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
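    """
    Attach a `CustomOffloadHook` to `model` and return a `UserCustomOffloadHook` wrapper around it.

    Example (illustrative sketch; `unet` and `vae` are assumed to be arbitrary `torch.nn.Module`s):

    ```py
    >>> unet_hook = custom_offload_with_hook("unet", unet, "cuda:0")
    >>> vae_hook = custom_offload_with_hook("vae", vae, "cuda:0")
    >>> unet_hook.add_other_hook(vae_hook)  # running unet may now offload vae first
    >>> vae_hook.add_other_hook(unet_hook)
    ```
    """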
    hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
    user_hook.attach()
    return user_hook


class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU, based
    on the available memory on the device.
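
    Example (illustrative sketch; `unet` stands in for any `torch.nn.Module`):

    ```py
    >>> strategy = AutoOffloadStrategy(memory_reserve_margin="5GB")
    >>> hook = custom_offload_with_hook("unet", unet, torch.device("cuda:0"), offload_strategy=strategy)
    ```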
    """

    def __init__(self, memory_reserve_margin="3GB"):
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = get_memory_footprint(model)

        # available memory on the execution device (assumes a CUDA device; `mem_get_info` is CUDA-only)
        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
        mem_on_device = mem_on_device - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f"searching for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB of memory")

        # exclude models that are not currently loaded on the device, sorted largest first
        module_sizes = dict(
            sorted(
                {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the optimal combination of models to offload to the CPU, given a dictionary of module sizes
            and a minimum amount of memory that needs to be freed. The chosen combination is the one with the
            smallest total size that is still at least `min_memory_offload`. For example, with sizes
            {"a": 4, "b": 2, "c": 1} and `min_memory_offload=3`, ("b", "c") (total size 3) is chosen over
            ("a",) (total size 4).
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    # keep the smallest combination that frees at least `min_memory_offload`
                    if min_memory_offload <= candidate_size < best_size:
                        best_candidate = candidate_model_ids
                        best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # no combination can meet the memory requirement, so offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload


class ComponentsManager:
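    """
    A registry that tracks pipeline components by name and can auto-offload every `torch.nn.Module` component
    between the CPU and a single execution device.

    Example (illustrative sketch; the checkpoint name and the "unet" component are placeholders):

    ```py
    >>> manager = ComponentsManager()
    >>> manager.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
    >>> manager.enable_auto_cpu_offload("cuda")
    >>> unet = manager.get("unet")
    ```
    """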
    def __init__(self):
        self.components = OrderedDict()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def add(self, name, component):
        if name not in self.components:
            self.components[name] = component
            if self._auto_offload_enabled:
                self.enable_auto_cpu_offload(self._auto_offload_device)

    def remove(self, name):
        self.components.pop(name)
        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def get(self, names: Union[str, List[str]]):
        if isinstance(names, str):
            if names not in self.components:
                raise ValueError(f"Component '{names}' not found in ComponentsManager")
            return self.components[names]
        elif isinstance(names, list):
            return {n: self.components[n] for n in names}
        else:
            raise ValueError(f"Invalid type for names: {type(names)}")

    def enable_auto_cpu_offload(self, device, memory_reserve_margin="3GB"):
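        """
        Attach auto-offload hooks to every `torch.nn.Module` component, using `device` as the shared execution
        device and keeping `memory_reserve_margin` of headroom free on it.
        """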
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                remove_hook_from_module(component, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:0")
        all_hooks = []
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
                all_hooks.append(hook)

        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def __repr__(self):
        col_widths = {
            "id": max([15] + [len(name) for name in self.components.keys()]),
            "class": max([25] + [len(component.__class__.__name__) for component in self.components.values()]),
            "device": 10,
            "dtype": 15,
            "size": 10,
        }

        # Create the header lines
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "Components:\n" + sep_line

        # Separate components into models and others
        models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
        others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

        # Models section
        if models:
            output += "Models:\n" + dash_line
            # Column headers
            output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
            output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
            output += dash_line

            # Model entries
            for name, component in models.items():
                device = component.device
                dtype = component.dtype
                size_bytes = get_memory_footprint(component)
                size_gb = size_bytes / (1024**3)

                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | "
                output += (
                    f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n"
                )
            output += dash_line

        # Other components section
        if others:
            if models:  # Add an extra newline if we had a models section
                output += "\n"
            output += "Other Components:\n" + dash_line
            # Column headers for other components
            output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
            output += dash_line

            # Other component entries
            for name, component in others.items():
                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
            output += dash_line

        return output

    def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):
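        """
        Instantiate a `DiffusionPipeline` from a pretrained checkpoint and register each of its components with
        this manager under the component's pipeline name.
        """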
        from ..pipelines.pipeline_utils import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
        for name, component in pipe.components.items():
            if name not in self.components and component is not None:
                self.add(name, component)
            elif name in self.components:
                logger.warning(
                    f"Component '{name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
                    f"1. Remove the existing component with remove('{name}')\n"
                    f"2. Use a different name: add('{name}_2', component)"
                )