Commit e08285e ("update")
Parent: 1871a69

This commit adds a per-hook _is_enabled flag to ModelHook, helper functions in hooks.py to toggle or remove named hooks across a module tree, private ModelMixin wrappers around those helpers, and tests covering both.

File tree: 3 files changed (+149, -2 lines)


src/diffusers/hooks/hooks.py

Lines changed: 21 additions & 2 deletions
@@ -32,6 +32,7 @@ class ModelHook:

     def __init__(self):
         self.fn_ref: "HookFunctionReference" = None
+        self._is_enabled = True

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
@@ -142,8 +143,10 @@ def register_hook(self, hook: ModelHook, name: str) -> None:

         self._module_ref = hook.initialize_hook(self._module_ref)

-        def create_new_forward(function_reference: HookFunctionReference):
+        def create_new_forward(hook: ModelHook, function_reference: HookFunctionReference):
             def new_forward(module, *args, **kwargs):
+                if not hook._is_enabled:
+                    return function_reference.forward(*args, **kwargs)
                 args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
                 output = function_reference.forward(*args, **kwargs)
                 return function_reference.post_forward(module, output)
@@ -163,7 +166,7 @@ def new_forward(module, *args, **kwargs):
                 functools.partial(hook.new_forward, self._module_ref), hook.new_forward
             )

-        rewritten_forward = create_new_forward(fn_ref)
+        rewritten_forward = create_new_forward(hook, fn_ref)
         self._module_ref.forward = functools.update_wrapper(
             functools.partial(rewritten_forward, self._module_ref), rewritten_forward
         )
@@ -234,3 +237,19 @@ def __repr__(self) -> str:
             if i < len(self._hook_order) - 1:
                 registry_repr += "\n"
         return f"HookRegistry(\n{registry_repr}\n)"
+
+
+def _set_hook_state(module: torch.nn.Module, name: str, value: bool) -> None:
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook"):
+            hook = submodule._diffusers_hook.get_hook(name)
+            if hook is not None:
+                hook._is_enabled = value
+
+
+def _remove_all_hooks(module: torch.nn.Module):
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook"):
+            for hook_name in list(submodule._diffusers_hook.hooks.keys()):
+                submodule._diffusers_hook.remove_hook(hook_name, recurse=False)
+            del submodule._diffusers_hook
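
In effect, new_forward now consults the owning hook's _is_enabled flag before doing anything else: a disabled hook passes the call straight through to the forward it wrapped, skipping pre_forward and post_forward while leaving the wrapper in place. Below is a minimal sketch of driving the new _set_hook_state helper directly on a bare module; the ScaleHook class and the "scale" hook name are illustrative inventions, not part of this commit.

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.hooks.hooks import _set_hook_state


class ScaleHook(ModelHook):
    # Illustrative hook: double whatever the module returns.
    def post_forward(self, module, output):
        return output * 2


linear = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(linear)
registry.register_hook(ScaleHook(), "scale")

x = torch.randn(1, 4)
y_hooked = linear(x)                     # post_forward doubles the output

_set_hook_state(linear, "scale", False)  # flips _is_enabled; forward stays wrapped
y_plain = linear(x)                      # wrapper falls through to the original forward

_set_hook_state(linear, "scale", True)
assert torch.allclose(y_hooked, y_plain * 2)
assert torch.allclose(y_hooked, linear(x))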

src/diffusers/models/modeling_utils.py

Lines changed: 38 additions & 0 deletions
@@ -1755,6 +1755,44 @@ def recursive_find_attn_block(name, module):
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
         return state_dict

+    def _enable_hook(self, name: str) -> None:
+        r"""
+        This method enables the hook with the given name on the model and all its submodules.
+
+        Args:
+            name (`str`):
+                The name of the hook to enable.
+
+        This method is experimental and may change in future versions without a backwards compatibility guarantee.
+        """
+        from ..hooks.hooks import _set_hook_state
+
+        _set_hook_state(self, name, True)
+
+    def _disable_hook(self, name: str) -> None:
+        r"""
+        This method disables the hook with the given name on the model and all its submodules.
+
+        Args:
+            name (`str`):
+                The name of the hook to disable.
+
+        This method is experimental and may change in future versions without a backwards compatibility guarantee.
+        """
+        from ..hooks.hooks import _set_hook_state
+
+        _set_hook_state(self, name, False)
+
+    def _remove_all_hooks(self) -> None:
+        r"""
+        This method removes all hooks from the model and all its submodules.
+
+        This method is experimental and may change in future versions without a backwards compatibility guarantee.
+        """
+        from ..hooks.hooks import _remove_all_hooks
+
+        _remove_all_hooks(self)
+

 class LegacyModelMixin(ModelMixin):
     r"""

tests/hooks/test_hooks.py

Lines changed: 90 additions & 0 deletions
@@ -17,7 +17,9 @@

 import torch

+from diffusers.configuration_utils import ConfigMixin
 from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger
 from diffusers.utils.testing_utils import CaptureLogger, torch_device
@@ -61,6 +63,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class DummyModelWithMixin(ModelMixin, ConfigMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
 class AddHook(ModelHook):
     def __init__(self, value: int):
         super().__init__()
@@ -380,3 +402,71 @@ def test_invocation_order_stateful_last(self):
             .replace("\n", "")
         )
         self.assertEqual(output, expected_invocation_order_log)
+
+
+class ModelMixinHookTests(unittest.TestCase):
+    in_features = 4
+    hidden_features = 8
+    out_features = 4
+    num_layers = 2
+
+    def setUp(self):
+        params = self.get_module_parameters()
+        self.model = DummyModelWithMixin(**params)
+        self.model.to(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+
+        del self.model
+        gc.collect()
+        free_memory()
+
+    def get_module_parameters(self):
+        return {
+            "in_features": self.in_features,
+            "hidden_features": self.hidden_features,
+            "out_features": self.out_features,
+            "num_layers": self.num_layers,
+        }
+
+    def get_generator(self):
+        return torch.manual_seed(0)
+
+    def test_enable_disable_hook(self):
+        registry = HookRegistry.check_if_exists_or_initialize(self.model)
+        registry.register_hook(AddHook(1), "add_hook")
+        registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+        output1 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("multiply_hook")
+        output2 = self.model(input).mean().detach().cpu().item()
+
+        self.model._enable_hook("multiply_hook")
+        output3 = self.model(input).mean().detach().cpu().item()
+
+        self.assertNotEqual(output1, output2)
+        self.assertEqual(output1, output3)
+
+    def test_remove_all_hooks(self):
+        registry = HookRegistry.check_if_exists_or_initialize(self.model)
+        registry.register_hook(AddHook(1), "add_hook")
+        registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+        output1 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("add_hook")
+        self.model._disable_hook("multiply_hook")
+        output2 = self.model(input).mean().detach().cpu().item()
+
+        self.model._remove_all_hooks()
+        output3 = self.model(input).mean().detach().cpu().item()
+
+        for module in self.model.modules():
+            self.assertFalse(hasattr(module, "_diffusers_hook"))
+
+        self.assertNotEqual(output1, output3)
+        self.assertEqual(output2, output3)
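
The closing assertions pin down the intended semantics: disabling every hook is observationally identical to removing them (output2 equals output3), and a re-enabled hook restores the original behavior exactly (output1 equals output3 in the first test). Assuming pytest is installed and the repository root is the working directory, the new test class can be run on its own, for example:

# Hypothetical runner snippet; selects only the class added in this commit.
import pytest

pytest.main(["tests/hooks/test_hooks.py::ModelMixinHookTests", "-v"])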
