|  | 
| 17 | 17 | 
 | 
| 18 | 18 | import torch | 
| 19 | 19 | 
 | 
|  | 20 | +from diffusers.configuration_utils import ConfigMixin | 
| 20 | 21 | from diffusers.hooks import HookRegistry, ModelHook | 
|  | 22 | +from diffusers.models.modeling_utils import ModelMixin | 
| 21 | 23 | from diffusers.training_utils import free_memory | 
| 22 | 24 | from diffusers.utils.logging import get_logger | 
| 23 | 25 | from diffusers.utils.testing_utils import CaptureLogger, torch_device | 
| @@ -61,6 +63,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | 
| 61 | 63 |         return x | 
| 62 | 64 | 
 | 
| 63 | 65 | 
 | 
|  | 66 | +class DummyModelWithMixin(ModelMixin, ConfigMixin): | 
|  | 67 | +    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: | 
|  | 68 | +        super().__init__() | 
|  | 69 | + | 
|  | 70 | +        self.linear_1 = torch.nn.Linear(in_features, hidden_features) | 
|  | 71 | +        self.activation = torch.nn.ReLU() | 
|  | 72 | +        self.blocks = torch.nn.ModuleList( | 
|  | 73 | +            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] | 
|  | 74 | +        ) | 
|  | 75 | +        self.linear_2 = torch.nn.Linear(hidden_features, out_features) | 
|  | 76 | + | 
|  | 77 | +    def forward(self, x: torch.Tensor) -> torch.Tensor: | 
|  | 78 | +        x = self.linear_1(x) | 
|  | 79 | +        x = self.activation(x) | 
|  | 80 | +        for block in self.blocks: | 
|  | 81 | +            x = block(x) | 
|  | 82 | +        x = self.linear_2(x) | 
|  | 83 | +        return x | 
|  | 84 | + | 
|  | 85 | + | 
| 64 | 86 | class AddHook(ModelHook): | 
| 65 | 87 |     def __init__(self, value: int): | 
| 66 | 88 |         super().__init__() | 
| @@ -380,3 +402,71 @@ def test_invocation_order_stateful_last(self): | 
| 380 | 402 |             .replace("\n", "") | 
| 381 | 403 |         ) | 
| 382 | 404 |         self.assertEqual(output, expected_invocation_order_log) | 
|  | 405 | + | 
|  | 406 | + | 
|  | 407 | +class ModelMixinHookTests(unittest.TestCase): | 
|  | 408 | +    in_features = 4 | 
|  | 409 | +    hidden_features = 8 | 
|  | 410 | +    out_features = 4 | 
|  | 411 | +    num_layers = 2 | 
|  | 412 | + | 
|  | 413 | +    def setUp(self): | 
|  | 414 | +        params = self.get_module_parameters() | 
|  | 415 | +        self.model = DummyModelWithMixin(**params) | 
|  | 416 | +        self.model.to(torch_device) | 
|  | 417 | + | 
|  | 418 | +    def tearDown(self): | 
|  | 419 | +        super().tearDown() | 
|  | 420 | + | 
|  | 421 | +        del self.model | 
|  | 422 | +        gc.collect() | 
|  | 423 | +        free_memory() | 
|  | 424 | + | 
|  | 425 | +    def get_module_parameters(self): | 
|  | 426 | +        return { | 
|  | 427 | +            "in_features": self.in_features, | 
|  | 428 | +            "hidden_features": self.hidden_features, | 
|  | 429 | +            "out_features": self.out_features, | 
|  | 430 | +            "num_layers": self.num_layers, | 
|  | 431 | +        } | 
|  | 432 | + | 
|  | 433 | +    def get_generator(self): | 
|  | 434 | +        return torch.manual_seed(0) | 
|  | 435 | + | 
|  | 436 | +    def test_enable_disable_hook(self): | 
|  | 437 | +        registry = HookRegistry.check_if_exists_or_initialize(self.model) | 
|  | 438 | +        registry.register_hook(AddHook(1), "add_hook") | 
|  | 439 | +        registry.register_hook(MultiplyHook(2), "multiply_hook") | 
|  | 440 | + | 
|  | 441 | +        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) | 
|  | 442 | +        output1 = self.model(input).mean().detach().cpu().item() | 
|  | 443 | + | 
|  | 444 | +        self.model._disable_hook("multiply_hook") | 
|  | 445 | +        output2 = self.model(input).mean().detach().cpu().item() | 
|  | 446 | + | 
|  | 447 | +        self.model._enable_hook("multiply_hook") | 
|  | 448 | +        output3 = self.model(input).mean().detach().cpu().item() | 
|  | 449 | + | 
|  | 450 | +        self.assertNotEqual(output1, output2) | 
|  | 451 | +        self.assertEqual(output1, output3) | 
|  | 452 | + | 
|  | 453 | +    def test_remove_all_hooks(self): | 
|  | 454 | +        registry = HookRegistry.check_if_exists_or_initialize(self.model) | 
|  | 455 | +        registry.register_hook(AddHook(1), "add_hook") | 
|  | 456 | +        registry.register_hook(MultiplyHook(2), "multiply_hook") | 
|  | 457 | + | 
|  | 458 | +        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) | 
|  | 459 | +        output1 = self.model(input).mean().detach().cpu().item() | 
|  | 460 | + | 
|  | 461 | +        self.model._disable_hook("add_hook") | 
|  | 462 | +        self.model._disable_hook("multiply_hook") | 
|  | 463 | +        output2 = self.model(input).mean().detach().cpu().item() | 
|  | 464 | + | 
|  | 465 | +        self.model._remove_all_hooks() | 
|  | 466 | +        output3 = self.model(input).mean().detach().cpu().item() | 
|  | 467 | + | 
|  | 468 | +        for module in self.model.modules(): | 
|  | 469 | +            self.assertFalse(hasattr(module, "_diffusers_hook")) | 
|  | 470 | + | 
|  | 471 | +        self.assertNotEqual(output1, output3) | 
|  | 472 | +        self.assertEqual(output2, output3) | 
0 commit comments