|
21 | 21 | # Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py |
22 | 22 | class ModelHook: |
23 | 23 | r""" |
24 | | - A hook that contains callbacks to be executed just before and after the forward method of a model. The difference |
25 | | - with PyTorch existing hooks is that they get passed along the kwargs. |
| 24 | + A hook that contains callbacks to be executed just before and after the forward method of a model. |
26 | 25 | """ |
27 | 26 |
|
| 27 | + _is_stateful = False |
| 28 | + |
28 | 29 | def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: |
29 | 30 | r""" |
30 | 31 | Hook that is executed when a model is initialized. |
@@ -78,6 +79,10 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: |
78 | 79 | """ |
79 | 80 | return module |
80 | 81 |
|
| 82 | + def reset_state(self, module: torch.nn.Module): |
| 83 | + if self._is_stateful: |
| 84 | + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") |
| 85 | + |
81 | 86 |
|
82 | 87 | class SequentialHook(ModelHook): |
83 | 88 | r"""A hook that can contain several hooks and iterates through them at each event.""" |
@@ -105,8 +110,13 @@ def detach_hook(self, module): |
105 | 110 | module = hook.detach_hook(module) |
106 | 111 | return module |
107 | 112 |
|
| 113 | + def reset_state(self, module): |
| 114 | + for hook in self.hooks: |
| 115 | + if hook._is_stateful: |
| 116 | + hook.reset_state(module) |
| 117 | + |
108 | 118 |
|
109 | | -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): |
| 119 | +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: |
110 | 120 | r""" |
111 | 121 | Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove |
112 | 122 | this behavior and restore the original `forward` method, use `remove_hook_from_module`. |
@@ -199,3 +209,21 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t |
199 | 209 | remove_hook_from_module(child, recurse) |
200 | 210 |
|
201 | 211 | return module |
| 212 | + |
| 213 | + |
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False) -> None:
    """
    Resets the state of all stateful hooks attached to a module.

    Args:
        module (`torch.nn.Module`):
            The module to reset the stateful hooks from.
        recurse (`bool`, defaults to `False`):
            Whether to also reset the stateful hooks attached to the module's children, recursively.
    """
    # Single lookup instead of hasattr() followed by repeated attribute access.
    hook = getattr(module, "_diffusers_hook", None)
    # A SequentialHook is not stateful itself but may contain stateful children,
    # and its reset_state dispatches only to those.
    if hook is not None and (hook._is_stateful or isinstance(hook, SequentialHook)):
        hook.reset_state(module)

    if recurse:
        for child in module.children():
            reset_stateful_hooks(child, recurse)
0 commit comments