|  | 
|  | 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import functools | 
|  | 16 | +from typing import Any, Callable, Dict, Tuple | 
|  | 17 | + | 
|  | 18 | +import torch | 
|  | 19 | + | 
|  | 20 | + | 
|  | 21 | +# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py | 
class ModelHook:
    r"""
    Base class for hooks whose callbacks run around a module's forward method.

    Each callback receives the module it is attached to. `pre_forward` is handed both the positional and the keyword
    arguments of the call, so a subclass can inspect or rewrite either. Every method of this base class is a
    pass-through that returns its input unchanged; subclasses override only the events they care about.
    """

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback for the moment the hook is attached to a module.
        Args:
            module (`torch.nn.Module`):
                The module this hook is being attached to.
        Returns:
            `torch.nn.Module`: The module to use from now on. The base implementation returns `module` as is.
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Callback that runs right before the module's forward method.
        Args:
            module (`torch.nn.Module`):
                The module whose forward pass is about to run.
            args (`Tuple[Any]`):
                Positional arguments of the forward call.
            kwargs (`Dict[str, Any]`):
                Keyword arguments of the forward call.
        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                The `(args, kwargs)` pair to forward. The base implementation returns both untouched.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Callback that runs right after the module's forward method.
        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has just run.
            output (`Any`):
                The value produced by the forward pass.
        Returns:
            `Any`: The value handed back to the caller. The base implementation returns `output` as is.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback for the moment the hook is detached from a module.
        Args:
            module (`torch.nn.Module`):
                The module this hook is being detached from.
        Returns:
            `torch.nn.Module`: The module after detaching. The base implementation returns `module` as is.
        """
        return module

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback for clearing any state the hook keeps between calls.
        Args:
            module (`torch.nn.Module`):
                The module this hook is attached to.
        Returns:
            `torch.nn.Module`: The module. The base implementation keeps no state and returns `module` as is.
        """
        return module
|  | 77 | + | 
|  | 78 | + | 
class SequentialHook(ModelHook):
    r"""
    A composite hook: holds several hooks and, for every event, calls each of them in the order they were given,
    feeding the result of one hook into the next.
    """

    def __init__(self, *hooks):
        # Stored as the tuple received through *hooks; iteration order is the call order for every event.
        self.hooks = hooks

    def init_hook(self, module):
        r"""Run `init_hook` of every contained hook in order, threading the returned module through the chain."""
        return functools.reduce(lambda current, member: member.init_hook(current), self.hooks, module)

    def pre_forward(self, module, *args, **kwargs):
        r"""
        Run `pre_forward` of every contained hook in order. Each hook receives the `args`/`kwargs` returned by the
        previous one; the final pair is returned.
        """
        call_args, call_kwargs = args, kwargs
        for member in self.hooks:
            call_args, call_kwargs = member.pre_forward(module, *call_args, **call_kwargs)
        return call_args, call_kwargs

    def post_forward(self, module, output):
        r"""Run `post_forward` of every contained hook in order, threading the returned output through the chain."""
        return functools.reduce(lambda current, member: member.post_forward(module, current), self.hooks, output)

    def detach_hook(self, module):
        r"""Run `detach_hook` of every contained hook in order, threading the returned module through the chain."""
        return functools.reduce(lambda current, member: member.detach_hook(current), self.hooks, module)

    def reset_state(self, module):
        r"""Run `reset_state` of every contained hook in order, threading the returned module through the chain."""
        return functools.reduce(lambda current, member: member.reset_state(current), self.hooks, module)
|  | 109 | + | 
|  | 110 | + | 
class FasterCacheHook(ModelHook):
    r"""
    A hook that remembers the last output seen by `post_forward` and can return it again instead of running the
    module's original forward.

    Args:
        skip_callback (`Callable[[torch.nn.Module], bool]`):
            Called with the module on a forward call when a cached output exists. A truthy result means the original
            forward is skipped and the cached output is reused.
    """

    def __init__(
        self,
        skip_callback: Callable[[torch.nn.Module], bool],
    ) -> None:
        super().__init__()

        self.skip_callback = skip_callback

        # Last output passed to `post_forward`; `None` means nothing has been cached yet.
        self.cache = None
        # Only ever set to 0 (here and in `reset_state`); nothing in this class reads or increments it.
        self._iteration = 0

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        r"""
        Replacement forward for a module carrying this hook.

        Runs the `pre_forward` of the hook stored on `module._diffusers_hook`, then either reuses the cached output
        or calls `module._old_forward`, and finally passes the result through that hook's `post_forward`.

        Args:
            module (`torch.nn.Module`):
                The hooked module. It must expose `_diffusers_hook` and `_old_forward`.
            args, kwargs:
                The arguments of the forward call.
        Returns:
            `Any`: Whatever `post_forward` of `module._diffusers_hook` returns.
        """
        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)

        # The callback is consulted only once something is cached, so the first call always computes.
        reuse_cached = self.cache is not None and self.skip_callback(module)
        output = self.cache if reuse_cached else module._old_forward(*args, **kwargs)

        return module._diffusers_hook.post_forward(module, output)

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""Store `output` as the cached value (by reference, no copy) and return it unchanged."""
        self.cache = output
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Drop the cached output, set `_iteration` back to 0 and return `module`."""
        self.cache = None
        self._iteration = 0
        return module
|  | 141 | + | 
|  | 142 | + | 
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
    r"""
    Attaches a hook to a module by replacing the module's `forward` with a wrapper that calls the hook. Use
    `remove_hook_from_module` to undo this and get the original `forward` back.
    <Tip warning={true}>
    By default a hook already present on the module is replaced by the new one. Pass `append=True` to keep the
    existing hook and chain it with the new one inside a `SequentialHook`.
    </Tip>
    Args:
        module (`torch.nn.Module`):
            The module that receives the hook.
        hook (`ModelHook`):
            The hook to attach. If it defines `new_forward`, that method is used as the replacement forward;
            only the hook passed here is checked, not a hook it gets chained with.
        append (`bool`, *optional*, defaults to `False`):
            Whether to chain `hook` after a hook that is already attached instead of replacing it.
    Returns:
        `torch.nn.Module`:
            The module with the hook attached. The module is modified in place, so the result can be ignored.
    """

    def _default_forward(module, *args, **kwargs):
        # pre_forward -> original forward -> post_forward, always through the hook currently stored on the module.
        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
        output = module._old_forward(*args, **kwargs)
        return module._diffusers_hook.post_forward(module, output)

    # Keep a handle on the hook as passed in: `hook` may be rebound to a SequentialHook below.
    incoming_hook = hook

    if append:
        existing_hook = getattr(module, "_diffusers_hook", None)
        if existing_hook is not None:
            # Detach first so the module is back to its original forward, then chain old and new.
            remove_hook_from_module(module)
            hook = SequentialHook(existing_hook, hook)

    already_hooked = hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward")
    if already_hooked:
        # A hook is being replaced: the true original forward is the one saved earlier.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._diffusers_hook = hook

    forward_impl = incoming_hook.new_forward if hasattr(incoming_hook, "new_forward") else _default_forward

    # Bind the module as first argument and copy the original forward's metadata onto the wrapper.
    wrapped_forward = functools.update_wrapper(functools.partial(forward_impl, module), old_forward)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = wrapped_forward
    else:
        module.forward = wrapped_forward

    return module
|  | 196 | + | 
|  | 197 | + | 
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
    """
    Detaches whatever hook `add_hook_to_module` put on a module and restores the saved forward.
    Args:
        module (`torch.nn.Module`):
            The module to detach the hook from.
        recurse (`bool`, defaults to `False`):
            Whether to also remove hooks from the module's children, recursively.
    Returns:
        `torch.nn.Module`:
            The same module, with the hook detached (the module is modified in place, so the result can be discarded).
    """

    if hasattr(module, "_diffusers_hook"):
        attached_hook = module._diffusers_hook
        # Let the hook clean up after itself before the reference to it is dropped.
        attached_hook.detach_hook(module)
        del module._diffusers_hook

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        is_graph_module = "GraphModuleImpl" in str(type(module))
        restore_target = module.__class__ if is_graph_module else module
        restore_target.forward = module._old_forward
        del module._old_forward

    if not recurse:
        return module

    for submodule in module.children():
        remove_hook_from_module(submodule, recurse)

    return module
0 commit comments