|  | 
|  | 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import functools | 
|  | 16 | +from typing import Any, Dict, Optional, Tuple | 
|  | 17 | + | 
|  | 18 | +import torch | 
|  | 19 | + | 
|  | 20 | +from ..utils.logging import get_logger | 
|  | 21 | + | 
|  | 22 | + | 
|  | 23 | +logger = get_logger(__name__)  # pylint: disable=invalid-name | 
|  | 24 | + | 
|  | 25 | + | 
|  | 26 | +class ModelHook: | 
|  | 27 | +    r""" | 
|  | 28 | +    A hook that contains callbacks to be executed just before and after the forward method of a model. | 
|  | 29 | +    """ | 
|  | 30 | + | 
|  | 31 | +    _is_stateful = False | 
|  | 32 | + | 
|  | 33 | +    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | 
|  | 34 | +        r""" | 
|  | 35 | +        Hook that is executed when a model is initialized. | 
|  | 36 | +
 | 
|  | 37 | +        Args: | 
|  | 38 | +            module (`torch.nn.Module`): | 
|  | 39 | +                The module attached to this hook. | 
|  | 40 | +        """ | 
|  | 41 | +        return module | 
|  | 42 | + | 
|  | 43 | +    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | 
|  | 44 | +        r""" | 
|  | 45 | +        Hook that is executed when a model is deinitalized. | 
|  | 46 | +
 | 
|  | 47 | +        Args: | 
|  | 48 | +            module (`torch.nn.Module`): | 
|  | 49 | +                The module attached to this hook. | 
|  | 50 | +        """ | 
|  | 51 | +        module.forward = module._old_forward | 
|  | 52 | +        del module._old_forward | 
|  | 53 | +        return module | 
|  | 54 | + | 
|  | 55 | +    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: | 
|  | 56 | +        r""" | 
|  | 57 | +        Hook that is executed just before the forward method of the model. | 
|  | 58 | +
 | 
|  | 59 | +        Args: | 
|  | 60 | +            module (`torch.nn.Module`): | 
|  | 61 | +                The module whose forward pass will be executed just after this event. | 
|  | 62 | +            args (`Tuple[Any]`): | 
|  | 63 | +                The positional arguments passed to the module. | 
|  | 64 | +            kwargs (`Dict[Str, Any]`): | 
|  | 65 | +                The keyword arguments passed to the module. | 
|  | 66 | +        Returns: | 
|  | 67 | +            `Tuple[Tuple[Any], Dict[Str, Any]]`: | 
|  | 68 | +                A tuple with the treated `args` and `kwargs`. | 
|  | 69 | +        """ | 
|  | 70 | +        return args, kwargs | 
|  | 71 | + | 
|  | 72 | +    def post_forward(self, module: torch.nn.Module, output: Any) -> Any: | 
|  | 73 | +        r""" | 
|  | 74 | +        Hook that is executed just after the forward method of the model. | 
|  | 75 | +
 | 
|  | 76 | +        Args: | 
|  | 77 | +            module (`torch.nn.Module`): | 
|  | 78 | +                The module whose forward pass been executed just before this event. | 
|  | 79 | +            output (`Any`): | 
|  | 80 | +                The output of the module. | 
|  | 81 | +        Returns: | 
|  | 82 | +            `Any`: The processed `output`. | 
|  | 83 | +        """ | 
|  | 84 | +        return output | 
|  | 85 | + | 
|  | 86 | +    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: | 
|  | 87 | +        r""" | 
|  | 88 | +        Hook that is executed when the hook is detached from a module. | 
|  | 89 | +
 | 
|  | 90 | +        Args: | 
|  | 91 | +            module (`torch.nn.Module`): | 
|  | 92 | +                The module detached from this hook. | 
|  | 93 | +        """ | 
|  | 94 | +        return module | 
|  | 95 | + | 
|  | 96 | +    def reset_state(self, module: torch.nn.Module): | 
|  | 97 | +        if self._is_stateful: | 
|  | 98 | +            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") | 
|  | 99 | +        return module | 
|  | 100 | + | 
|  | 101 | + | 
|  | 102 | +class HookRegistry: | 
|  | 103 | +    def __init__(self, module_ref: torch.nn.Module) -> None: | 
|  | 104 | +        super().__init__() | 
|  | 105 | + | 
|  | 106 | +        self.hooks: Dict[str, ModelHook] = {} | 
|  | 107 | + | 
|  | 108 | +        self._module_ref = module_ref | 
|  | 109 | +        self._hook_order = [] | 
|  | 110 | + | 
|  | 111 | +    def register_hook(self, hook: ModelHook, name: str) -> None: | 
|  | 112 | +        if name in self.hooks.keys(): | 
|  | 113 | +            logger.warning(f"Hook with name {name} already exists, replacing it.") | 
|  | 114 | + | 
|  | 115 | +        if hasattr(self._module_ref, "_old_forward"): | 
|  | 116 | +            old_forward = self._module_ref._old_forward | 
|  | 117 | +        else: | 
|  | 118 | +            old_forward = self._module_ref.forward | 
|  | 119 | +            self._module_ref._old_forward = self._module_ref.forward | 
|  | 120 | + | 
|  | 121 | +        self._module_ref = hook.initialize_hook(self._module_ref) | 
|  | 122 | + | 
|  | 123 | +        if hasattr(hook, "new_forward"): | 
|  | 124 | +            rewritten_forward = hook.new_forward | 
|  | 125 | + | 
|  | 126 | +            def new_forward(module, *args, **kwargs): | 
|  | 127 | +                args, kwargs = hook.pre_forward(module, *args, **kwargs) | 
|  | 128 | +                output = rewritten_forward(module, *args, **kwargs) | 
|  | 129 | +                return hook.post_forward(module, output) | 
|  | 130 | +        else: | 
|  | 131 | + | 
|  | 132 | +            def new_forward(module, *args, **kwargs): | 
|  | 133 | +                args, kwargs = hook.pre_forward(module, *args, **kwargs) | 
|  | 134 | +                output = old_forward(*args, **kwargs) | 
|  | 135 | +                return hook.post_forward(module, output) | 
|  | 136 | + | 
|  | 137 | +        self._module_ref.forward = functools.update_wrapper( | 
|  | 138 | +            functools.partial(new_forward, self._module_ref), old_forward | 
|  | 139 | +        ) | 
|  | 140 | + | 
|  | 141 | +        self.hooks[name] = hook | 
|  | 142 | +        self._hook_order.append(name) | 
|  | 143 | + | 
|  | 144 | +    def get_hook(self, name: str) -> Optional[ModelHook]: | 
|  | 145 | +        if name not in self.hooks.keys(): | 
|  | 146 | +            return None | 
|  | 147 | +        return self.hooks[name] | 
|  | 148 | + | 
|  | 149 | +    def remove_hook(self, name: str, recurse: bool = True) -> None: | 
|  | 150 | +        if name in self.hooks.keys(): | 
|  | 151 | +            hook = self.hooks[name] | 
|  | 152 | +            self._module_ref = hook.deinitalize_hook(self._module_ref) | 
|  | 153 | +            del self.hooks[name] | 
|  | 154 | +            self._hook_order.remove(name) | 
|  | 155 | + | 
|  | 156 | +        if recurse: | 
|  | 157 | +            for module_name, module in self._module_ref.named_modules(): | 
|  | 158 | +                if module_name == "": | 
|  | 159 | +                    continue | 
|  | 160 | +                if hasattr(module, "_diffusers_hook"): | 
|  | 161 | +                    module._diffusers_hook.remove_hook(name, recurse=False) | 
|  | 162 | + | 
|  | 163 | +    def reset_stateful_hooks(self, recurse: bool = True) -> None: | 
|  | 164 | +        for hook_name in self._hook_order: | 
|  | 165 | +            hook = self.hooks[hook_name] | 
|  | 166 | +            if hook._is_stateful: | 
|  | 167 | +                hook.reset_state(self._module_ref) | 
|  | 168 | + | 
|  | 169 | +        if recurse: | 
|  | 170 | +            for module_name, module in self._module_ref.named_modules(): | 
|  | 171 | +                if module_name == "": | 
|  | 172 | +                    continue | 
|  | 173 | +                if hasattr(module, "_diffusers_hook"): | 
|  | 174 | +                    module._diffusers_hook.reset_stateful_hooks(recurse=False) | 
|  | 175 | + | 
|  | 176 | +    @classmethod | 
|  | 177 | +    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": | 
|  | 178 | +        if not hasattr(module, "_diffusers_hook"): | 
|  | 179 | +            module._diffusers_hook = cls(module) | 
|  | 180 | +        return module._diffusers_hook | 
|  | 181 | + | 
|  | 182 | +    def __repr__(self) -> str: | 
|  | 183 | +        hook_repr = "" | 
|  | 184 | +        for i, hook_name in enumerate(self._hook_order): | 
|  | 185 | +            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" | 
|  | 186 | +            if i < len(self._hook_order) - 1: | 
|  | 187 | +                hook_repr += "\n" | 
|  | 188 | +        return f"HookRegistry(\n{hook_repr}\n)" | 
0 commit comments