# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Union

import torch

from .hooks import HookRegistry, ModelHook


# Attribute names under which diffusers models commonly store their stacks of transformer blocks.
_TRANSFORMER_STACK_IDENTIFIERS = [
    "transformer_blocks",
    "single_transformer_blocks",
    "temporal_transformer_blocks",
    "transformer_layers",
    "layers",
    "blocks",
]


class ModuleGroup:
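    r"""
    A group of torch.nn.Module instances that are offloaded and onloaded as a unit. The `onload_leader` is the module
    whose forward pass moves the whole group to `onload_device`, and the `offload_leader` is the module whose forward
    pass completion moves the whole group back to `offload_device`.
    """
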
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads them to the accelerator device
    for computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload
    leader" module that is responsible for offloading.

    This implementation assumes the following:
    - For offload_group_patterns="diffusers_block", the leaders of a group can be determined automatically. For a
      custom user-provided regex pattern, the module that triggers its forward pass first is considered the onload
      leader.
    - The inputs are already on the correct device. This is expected because the hook does not modify the state of
      inputs or outputs at any stage of the forward pass. If an error is raised because the devices of the modules
      and inputs do not match during the forward pass of any model in Diffusers, it means that the forward pass of
      the model is not written in the expected way. Please open an issue at
      https://github.com/huggingface/diffusers/issues if you encounter such an error.
    """

    def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None:
        self.group = group
        self.offload_on_init = offload_on_init

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.offload_on_init:
            self.offload_(module)
        return module

    def onload_(self, module: torch.nn.Module) -> None:
        # Only the onload leader moves the group; for every other module in the group this is a no-op.
        if self.group.onload_leader is not None and self.group.onload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.onload_device)

    def offload_(self, module: torch.nn.Module) -> None:
        # Only the offload leader moves the group back to the offload device.
        if self.group.offload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.offload_device)

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If no onload leader was specified, the first module to run a forward pass becomes the leader.
        if self.group.onload_leader is None:
            self.group.onload_leader = module
        self.onload_(module)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        self.offload_(module)
        return output


def apply_group_offloading(
    module: torch.nn.Module,
    offload_group_patterns: Union[str, List[str]] = "diffusers_block",
    num_blocks_per_group: Optional[int] = None,
    offload_device: torch.device = torch.device("cpu"),
    onload_device: torch.device = torch.device("cuda"),
    force_offload: bool = True,
) -> None:
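    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. Groups of modules are kept on
    `offload_device` for storage and moved to `onload_device` only for computation, which reduces peak memory usage
    at the cost of extra device transfers.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_group_patterns (`str` or `List[str]`, defaults to `"diffusers_block"`):
            Either the string `"diffusers_block"`, in which case groups are built from the module's stacks of
            transformer blocks, or a list of mutually exclusive regex patterns, each of which defines one group.
        num_blocks_per_group (`int`, *optional*):
            The number of transformer blocks per group. Must be provided when `offload_group_patterns` is
            `"diffusers_block"`.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the groups of modules are offloaded.
        onload_device (`torch.device`, defaults to `torch.device("cuda")`):
            The device to which the groups of modules are onloaded for computation.
        force_offload (`bool`, defaults to `True`):
            Whether to offload all groups when their hooks are initialized. With `"diffusers_block"`, the first group
            is offloaded on initialization regardless of this flag.

    Example (a minimal sketch rather than a prescribed recipe; it assumes a model, such as `FluxTransformer2DModel`,
    that stores its layers under one of the recognized block-stack attributes like `transformer_blocks`):

    ```python
    >>> import torch
    >>> from diffusers import FluxTransformer2DModel

    >>> transformer = FluxTransformer2DModel.from_pretrained(
    ...     "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    ... )
    >>> apply_group_offloading(transformer, num_blocks_per_group=2)
    ```
    """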
    if offload_group_patterns == "diffusers_block":
        _apply_group_offloading_diffusers_block(
            module, num_blocks_per_group, offload_device, onload_device, force_offload
        )
    else:
        _apply_group_offloading_group_patterns(
            module, offload_group_patterns, offload_device, onload_device, force_offload
        )


def _apply_group_offloading_diffusers_block(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
) -> None:
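    r"""
    Splits the transformer blocks found under one of the `_TRANSFORMER_STACK_IDENTIFIERS` attributes of `module`
    into chunks of `num_blocks_per_group` consecutive blocks. The first block of each chunk is used as the onload
    leader and the last block as the offload leader.
    """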
    if num_blocks_per_group is None:
        raise ValueError('num_blocks_per_group must be provided when using offload_group_patterns="diffusers_block".')

    for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS:
        if not hasattr(module, transformer_stack_identifier) or not isinstance(
            getattr(module, transformer_stack_identifier), torch.nn.ModuleList
        ):
            continue

        transformer_stack = getattr(module, transformer_stack_identifier)
        num_blocks = len(transformer_stack)

        for i in range(0, num_blocks, num_blocks_per_group):
            blocks = transformer_stack[i : i + num_blocks_per_group]
            group = ModuleGroup(
                blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
            )
            # The first group is always offloaded on init to keep the initial memory footprint low; the remaining
            # groups are offloaded on init only if force_offload is set.
            should_offload = force_offload or i == 0
            _apply_group_offloading(group, should_offload)


def _apply_group_offloading_group_patterns(
    module: torch.nn.Module,
    offload_group_patterns: List[str],
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
) -> None:
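    r"""
    Builds one offloading group per regex pattern in `offload_group_patterns`, matching each pattern against the
    submodule names returned by `module.named_modules()`. The patterns must be mutually exclusive: no submodule name
    may be matched by more than one pattern.

    For example (a hypothetical sketch; the pattern strings depend on the submodule names of the model at hand, here
    assumed to follow UNet-style naming with `down_blocks`, `mid_block`, and `up_blocks`), the down blocks, the mid
    block, and the up blocks could be placed into three groups through the public entry point:

    ```python
    >>> apply_group_offloading(
    ...     unet,
    ...     offload_group_patterns=[r"down_blocks\.\d+$", r"mid_block$", r"up_blocks\.\d+$"],
    ... )
    ```
    """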
    per_group_modules = []
    for offload_group_pattern in offload_group_patterns:
        group_modules = []
        group_module_names = []
        # Use a distinct loop variable so that the `module` argument is not shadowed across pattern iterations.
        for name, submodule in module.named_modules():
            if re.search(offload_group_pattern, name) is not None:
                group_modules.append(submodule)
                group_module_names.append(name)
        per_group_modules.append(
            {
                "modules": group_modules,
                "module_names": group_module_names,
            }
        )

    # Check if there are any overlapping modules between groups
    for i, group in enumerate(per_group_modules):
        for j, other_group in enumerate(per_group_modules):
            if j <= i:
                continue
            if any(module_name in group["module_names"] for module_name in other_group["module_names"]):
                raise ValueError(
                    f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns "
                    "are mutually exclusive."
                )

    # Apply offloading to each group
    for group in per_group_modules:
        # TODO: handle offload leader correctly
        group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
        _apply_group_offloading(group, force_offload)


def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool) -> None:
    for module in group.modules:
        hook = GroupOffloadingHook(group, offload_on_init=offload_on_init)
        registry = HookRegistry.check_if_exists_or_initialize(module)
        registry.register_hook(hook, "group_offloading")