Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions tests/hooks/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

import gc
import unittest

import pytest
import torch

from diffusers.hooks import HookRegistry, ModelHook
Expand Down Expand Up @@ -134,20 +134,18 @@ def post_forward(self, module, output):
return output


class HookTests(unittest.TestCase):
class TestHooks:
in_features = 4
hidden_features = 8
out_features = 4
num_layers = 2

def setUp(self):
def setup_method(self):
params = self.get_module_parameters()
self.model = DummyModel(**params)
self.model.to(torch_device)

def tearDown(self):
super().tearDown()

def teardown_method(self):
del self.model
gc.collect()
free_memory()
Expand All @@ -171,20 +169,20 @@ def test_hook_registry(self):
registry_repr = repr(registry)
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"

self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
self.assertEqual(registry_repr, expected_repr)
assert len(registry.hooks) == 2
assert registry._hook_order == ["add_hook", "multiply_hook"]
assert registry_repr == expected_repr

registry.remove_hook("add_hook")

self.assertEqual(len(registry.hooks), 1)
self.assertEqual(registry._hook_order, ["multiply_hook"])
assert len(registry.hooks) == 1
assert registry._hook_order == ["multiply_hook"]

def test_stateful_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
assert registry.hooks["stateful_add_hook"].increment == 0

input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
num_repeats = 3
Expand All @@ -194,13 +192,13 @@ def test_stateful_hook(self):
if i == 0:
output1 = result

self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
assert registry.get_hook("stateful_add_hook").increment == num_repeats

registry.reset_stateful_hooks()
output2 = self.model(input)

self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
self.assertTrue(torch.allclose(output1, output2))
assert registry.get_hook("stateful_add_hook").increment == 1
assert torch.allclose(output1, output2)

def test_inference(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
Expand All @@ -218,40 +216,39 @@ def test_inference(self):
new_input = input * 2 + 1
output3 = self.model(new_input).mean().detach().cpu().item()

self.assertAlmostEqual(output1, output2, places=5)
self.assertAlmostEqual(output1, output3, places=5)
self.assertAlmostEqual(output2, output3, places=5)
assert output1 == pytest.approx(output2, abs=5e-6)
assert output1 == pytest.approx(output3, abs=5e-6)
assert output2 == pytest.approx(output3, abs=5e-6)

def test_skip_layer_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

input = torch.zeros(1, 4, device=torch_device)
output = self.model(input).mean().detach().cpu().item()
self.assertEqual(output, 0.0)
assert output == 0.0

registry.remove_hook("skip_layer_hook")
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
assert output != 0.0

def test_skip_layer_internal_block(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
input = torch.zeros(1, 4, device=torch_device)

registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
with self.assertRaises(RuntimeError) as cm:
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
self.model(input).mean().detach().cpu().item()
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

registry.remove_hook("skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
assert output != 0.0

registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
self.assertNotEqual(output, 0.0)
assert output != 0.0

def test_invocation_order_stateful_first(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
Expand All @@ -278,7 +275,7 @@ def test_invocation_order_stateful_first(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
Expand All @@ -289,7 +286,7 @@ def test_invocation_order_stateful_first(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

def test_invocation_order_stateful_middle(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
Expand All @@ -316,7 +313,7 @@ def test_invocation_order_stateful_middle(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
Expand All @@ -327,7 +324,7 @@ def test_invocation_order_stateful_middle(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

registry.remove_hook("add_hook_2")
with CaptureLogger(logger) as cap_logger:
Expand All @@ -336,7 +333,7 @@ def test_invocation_order_stateful_middle(self):
expected_invocation_order_log = (
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

def test_invocation_order_stateful_last(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
Expand All @@ -363,7 +360,7 @@ def test_invocation_order_stateful_last(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log

registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
Expand All @@ -374,4 +371,4 @@ def test_invocation_order_stateful_last(self):
.replace(" ", "")
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)
assert output == expected_invocation_order_log
Loading