Skip to content

Commit 8546c9e

Browse files
committed
new_forward support
1 parent e08285e commit 8546c9e

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/diffusers/hooks/hooks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
146146
def create_new_forward(hook: ModelHook, function_reference: HookFunctionReference):
147147
def new_forward(module, *args, **kwargs):
148148
if not hook._is_enabled:
149-
return function_reference.forward(*args, **kwargs)
149+
return function_reference.original_forward(*args, **kwargs)
150150
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
151151
output = function_reference.forward(*args, **kwargs)
152152
return function_reference.post_forward(module, output)
@@ -158,6 +158,7 @@ def new_forward(module, *args, **kwargs):
158158
fn_ref = HookFunctionReference()
159159
fn_ref.pre_forward = hook.pre_forward
160160
fn_ref.post_forward = hook.post_forward
161+
fn_ref.original_forward = forward
161162
fn_ref.forward = forward
162163

163164
if hasattr(hook, "new_forward"):

tests/hooks/test_hooks.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,31 @@ def test_enable_disable_hook(self):
450450
self.assertNotEqual(output1, output2)
451451
self.assertEqual(output1, output3)
452452

453+
def test_enable_disable_hook_containing_new_forward(self):
454+
registry = HookRegistry.check_if_exists_or_initialize(self.model)
455+
registry.register_hook(AddHook(1), "add_hook")
456+
for block in self.model.blocks:
457+
block_registry = HookRegistry.check_if_exists_or_initialize(block)
458+
block_registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
459+
registry.register_hook(MultiplyHook(2), "multiply_hook")
460+
461+
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
462+
output1 = self.model(input).mean().detach().cpu().item()
463+
464+
self.model._disable_hook("skip_layer_hook")
465+
output2 = self.model(input).mean().detach().cpu().item()
466+
467+
self.model._disable_hook("add_hook")
468+
output3 = self.model(input).mean().detach().cpu().item()
469+
470+
self.model._enable_hook("skip_layer_hook")
471+
self.model._enable_hook("add_hook")
472+
output4 = self.model(input).mean().detach().cpu().item()
473+
474+
self.assertNotEqual(output1, output2)
475+
self.assertNotEqual(output2, output3)
476+
self.assertEqual(output1, output4)
477+
453478
def test_remove_all_hooks(self):
454479
registry = HookRegistry.check_if_exists_or_initialize(self.model)
455480
registry.register_hook(AddHook(1), "add_hook")

0 commit comments

Comments
 (0)