|
34 | 34 | torch_device, |
35 | 35 | ) |
36 | 36 |
|
| 37 | +from typing import Any, Iterable, List, Optional, Sequence, Union |
37 | 38 |
|
38 | 39 | class DummyBlock(torch.nn.Module): |
39 | 40 | def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: |
@@ -216,6 +217,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
216 | 217 | x = block(x) |
217 | 218 | x = self.norm(x) |
218 | 219 | return x |
| 220 | + |
# Test for https://github.com/huggingface/diffusers/pull/12747
class DummyCallableBySubmodule:
    """
    Callable group-offloading pinner used as `pin_groups=callable(submodule)`.

    Pins exactly the modules passed as `pin_targets` (in the tests: the first
    and last DummyBlock). Every invocation is recorded in `calls_track` so the
    tests can assert the callable was actually consulted.
    """

    def __init__(self, pin_targets: Iterable[torch.nn.Module]) -> None:
        self.pin_targets = set(pin_targets)
        self.calls_track: List[Any] = []  # testing only

    def __call__(self, submodule: torch.nn.Module) -> bool:
        self.calls_track.append(submodule)
        candidate = self._normalize_module_type(submodule)
        return candidate in self.pin_targets

    def _normalize_module_type(self, obj: Any) -> Optional[torch.nn.Module]:
        """
        Resolve `obj` to a single `torch.nn.Module`, or None when ambiguous.

        The group-offloading code may pass either a single module or a
        list/tuple of modules. Only an unambiguous mapping is accepted:
          - a module -> returned as-is
          - a list/tuple holding exactly one module -> that module
          - anything else (no modules, several modules, other types) -> None
        """
        if isinstance(obj, torch.nn.Module):
            return obj
        if not isinstance(obj, (list, tuple)):
            return None
        resolved = None
        for item in obj:
            if not isinstance(item, torch.nn.Module):
                continue
            if resolved is not None:
                return None  # more than one module -> ambiguous
            resolved = item
        return resolved
| 252 | + |
| 253 | + |
class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
    """
    Callable group-offloading pinner pinning the first and last DummyBlock.

    Behaves exactly like `DummyCallableBySubmodule` but exposes the two-argument
    signature, i.e. it is invoked as `callable(name, submodule)`.
    """

    def __call__(self, name: str, submodule: torch.nn.Module) -> bool:
        self.calls_track.append((name, submodule))
        candidate = self._normalize_module_type(submodule)
        return candidate in self.pin_targets
| 264 | + |
| 265 | + |
class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
    """
    Callable group-offloading pinner pinning the first and last DummyBlock.

    Behaves exactly like `DummyCallableBySubmodule` but exposes the
    three-argument signature, i.e. it is invoked as
    `callable(name, submodule, idx)`.
    """

    def __call__(self, name: str, submodule: torch.nn.Module, idx: int) -> bool:
        self.calls_track.append((name, submodule, idx))
        candidate = self._normalize_module_type(submodule)
        return candidate in self.pin_targets
| 276 | + |
| 277 | + |
class DummyInvalidCallable(DummyCallableBySubmodule):
    """
    Callable group-offloading pinner with an unsupported four-argument
    signature; used to verify the offloading code rejects invalid callables.
    """

    def __call__(self, name: str, submodule: torch.nn.Module, idx: int, extra: Any) -> bool:
        self.calls_track.append((name, submodule, idx, extra))
        candidate = self._normalize_module_type(submodule)
        return candidate in self.pin_targets
219 | 286 |
|
220 | 287 |
|
221 | 288 | @require_torch_accelerator |
@@ -566,3 +633,164 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non |
566 | 633 | "layers_per_block": 1, |
567 | 634 | } |
568 | 635 | return init_dict |
| 636 | + |
| 637 | + def test_block_level_offloading_with_pin_groups_stay_on_device(self): |
| 638 | + if torch.device(torch_device).type not in ["cuda", "xpu"]: |
| 639 | + return |
| 640 | + |
| 641 | + def assert_all_modules_on_expected_device( |
| 642 | + modules: Sequence[torch.nn.Module], expected_device: Union[torch.device, str], header_error_msg: str = "" |
| 643 | + ) -> None: |
| 644 | + def first_param_device(modules: torch.nn.Module) -> torch.device: |
| 645 | + p = next(modules.parameters(), None) |
| 646 | + self.assertIsNotNone(p, f"No parameters found for module {modules}") |
| 647 | + return p.device |
| 648 | + |
| 649 | + if isinstance(expected_device, torch.device): |
| 650 | + expected_device = expected_device.type |
| 651 | + |
| 652 | + bad = [] |
| 653 | + for i, m in enumerate(modules): |
| 654 | + dev_type = first_param_device(m).type |
| 655 | + if dev_type != expected_device: |
| 656 | + bad.append((i, m.__class__.__name__, dev_type)) |
| 657 | + self.assertTrue( |
| 658 | + len(bad) == 0, |
| 659 | + (header_error_msg + "\n" if header_error_msg else "") |
| 660 | + + f"Expected all modules on {expected_device}, but found mismatches: {bad}", |
| 661 | + ) |
| 662 | + |
| 663 | + def get_param_modules_from_execution_order(model: DummyModel) -> List[torch.nn.Module]: |
| 664 | + model.eval() |
| 665 | + root_registry = HookRegistry.check_if_exists_or_initialize(model) |
| 666 | + |
| 667 | + lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading") |
| 668 | + self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered") |
| 669 | + |
| 670 | + # record execution order with first forward |
| 671 | + with torch.no_grad(): |
| 672 | + model(self.input) |
| 673 | + |
| 674 | + mods = [m for _, m in lazy_hook.execution_order] |
| 675 | + param_modules = [m for m in mods if next(m.parameters(), None) is not None] |
| 676 | + return param_modules |
| 677 | + |
| 678 | + def assert_callables_offloading_tests( |
| 679 | + param_modules: Sequence[torch.nn.Module], |
| 680 | + callable: Any, |
| 681 | + header_error_msg: str = "", |
| 682 | + ) -> None: |
| 683 | + pinned_modules = [m for m in param_modules if m in callable.pin_targets] |
| 684 | + unpinned_modules = [m for m in param_modules if m not in callable.pin_targets] |
| 685 | + self.assertTrue( |
| 686 | + len(callable.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once" |
| 687 | + ) |
| 688 | + assert_all_modules_on_expected_device( |
| 689 | + pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device" |
| 690 | + ) |
| 691 | + assert_all_modules_on_expected_device( |
| 692 | + unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded" |
| 693 | + ) |
| 694 | + |
| 695 | + default_parameters = { |
| 696 | + "onload_device": torch_device, |
| 697 | + "offload_type": "block_level", |
| 698 | + "num_blocks_per_group": 1, |
| 699 | + "use_stream": True, |
| 700 | + } |
| 701 | + model_default_no_pin = self.get_model() |
| 702 | + model_default_no_pin.enable_group_offload(**default_parameters) |
| 703 | + param_modules = get_param_modules_from_execution_order(model_default_no_pin) |
| 704 | + assert_all_modules_on_expected_device( |
| 705 | + param_modules, |
| 706 | + expected_device="cpu", |
| 707 | + header_error_msg="default pin_groups: expected ALL modules offloaded to CPU", |
| 708 | + ) |
| 709 | + |
| 710 | + model_pin_all = self.get_model() |
| 711 | + model_pin_all.enable_group_offload( |
| 712 | + **default_parameters, |
| 713 | + pin_groups="all", |
| 714 | + ) |
| 715 | + param_modules = get_param_modules_from_execution_order(model_pin_all) |
| 716 | + assert_all_modules_on_expected_device( |
| 717 | + param_modules, |
| 718 | + expected_device=torch_device, |
| 719 | + header_error_msg="pin_groups = all: expected ALL layers on accelerator device", |
| 720 | + ) |
| 721 | + |
| 722 | + model_pin_first_last = self.get_model() |
| 723 | + model_pin_first_last.enable_group_offload( |
| 724 | + **default_parameters, |
| 725 | + pin_groups="first_last", |
| 726 | + ) |
| 727 | + param_modules = get_param_modules_from_execution_order(model_pin_first_last) |
| 728 | + assert_all_modules_on_expected_device( |
| 729 | + [param_modules[0], param_modules[-1]], |
| 730 | + expected_device=torch_device, |
| 731 | + header_error_msg="pin_groups = first_last: expected first and last layers on accelerator device", |
| 732 | + ) |
| 733 | + assert_all_modules_on_expected_device( |
| 734 | + param_modules[1:-1], |
| 735 | + expected_device="cpu", |
| 736 | + header_error_msg="pin_groups = first_last: expected ALL middle layers offloaded to CPU", |
| 737 | + ) |
| 738 | + |
| 739 | + model = self.get_model() |
| 740 | + callable_by_submodule = DummyCallableBySubmodule(pin_targets=[model.blocks[0], model.blocks[-1]]) |
| 741 | + model.enable_group_offload(**default_parameters, pin_groups=callable_by_submodule) |
| 742 | + param_modules = get_param_modules_from_execution_order(model) |
| 743 | + assert_callables_offloading_tests( |
| 744 | + param_modules, callable_by_submodule, header_error_msg="pin_groups with callable(submodule)" |
| 745 | + ) |
| 746 | + |
| 747 | + model = self.get_model() |
| 748 | + callable_by_name_submodule = DummyCallableByNameSubmodule(pin_targets=[model.blocks[0], model.blocks[-1]]) |
| 749 | + model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule) |
| 750 | + param_modules = get_param_modules_from_execution_order(model) |
| 751 | + assert_callables_offloading_tests( |
| 752 | + param_modules, callable_by_name_submodule, header_error_msg="pin_groups with callable(name, submodule)" |
| 753 | + ) |
| 754 | + |
| 755 | + model = self.get_model() |
| 756 | + callable_by_name_submodule_idx = DummyCallableByNameSubmoduleIdx( |
| 757 | + pin_targets=[model.blocks[0], model.blocks[-1]] |
| 758 | + ) |
| 759 | + model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule_idx) |
| 760 | + param_modules = get_param_modules_from_execution_order(model) |
| 761 | + assert_callables_offloading_tests( |
| 762 | + param_modules, |
| 763 | + callable_by_name_submodule_idx, |
| 764 | + header_error_msg="pin_groups with callable(name, submodule, idx)", |
| 765 | + ) |
| 766 | + |
| 767 | + def test_error_raised_if_pin_groups_received_invalid_value(self): |
| 768 | + default_parameters = { |
| 769 | + "onload_device": torch_device, |
| 770 | + "offload_type": "block_level", |
| 771 | + "num_blocks_per_group": 1, |
| 772 | + "use_stream": True, |
| 773 | + } |
| 774 | + model = self.get_model() |
| 775 | + with self.assertRaisesRegex(ValueError, "`pin_groups` must be None, 'all', 'first_last', or a callable."): |
| 776 | + model.enable_group_offload( |
| 777 | + **default_parameters, |
| 778 | + pin_groups="invalid value", |
| 779 | + ) |
| 780 | + |
| 781 | + def test_error_raised_if_pin_groups_received_invalid_callables(self): |
| 782 | + default_parameters = { |
| 783 | + "onload_device": torch_device, |
| 784 | + "offload_type": "block_level", |
| 785 | + "num_blocks_per_group": 1, |
| 786 | + "use_stream": True, |
| 787 | + } |
| 788 | + model = self.get_model() |
| 789 | + invalid_callable = DummyInvalidCallable(pin_targets=[model.blocks[0], model.blocks[-1]]) |
| 790 | + model.enable_group_offload( |
| 791 | + **default_parameters, |
| 792 | + pin_groups=invalid_callable, |
| 793 | + ) |
| 794 | + with self.assertRaisesRegex(TypeError, r"missing\s+\d+\s+required\s+positional\s+argument(s)?:"): |
| 795 | + with torch.no_grad(): |
| 796 | + model(self.input) |
0 commit comments