
Commit 2e8f538

restored feature/group-offload-tests
1 parent af61b9c commit 2e8f538

File tree

1 file changed: +228 −0


tests/hooks/test_group_offloading.py

Lines changed: 228 additions & 0 deletions
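
Before the diff itself, a minimal sketch of the API surface these tests exercise, assuming a diffusers module that exposes `enable_group_offload` (`model` here is a placeholder, and `pin_groups` is the new argument under test from https://github.com/huggingface/diffusers/pull/12747):

    import torch

    # `model` stands in for any module supporting enable_group_offload()
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,
        pin_groups="first_last",  # None (default) | "all" | "first_last" | callable
    )

With `pin_groups="first_last"`, the tests below expect the first and last parameterized blocks to stay on the accelerator while the middle blocks are offloaded to CPU.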
@@ -34,6 +34,7 @@
     torch_device,
 )
 
+from typing import Any, Iterable, List, Optional, Sequence, Union
 
 class DummyBlock(torch.nn.Module):
     def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
@@ -216,6 +217,72 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = block(x)
         x = self.norm(x)
         return x
+
+# Test for https://github.com/huggingface/diffusers/pull/12747
+class DummyCallableBySubmodule:
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Invoked by the group-offloading code as `callable(submodule)`.
+    """
+
+    def __init__(self, pin_targets: Iterable[torch.nn.Module]) -> None:
+        self.pin_targets = set(pin_targets)
+        self.calls_track = []  # testing only
+
+    def __call__(self, submodule: torch.nn.Module) -> bool:
+        self.calls_track.append(submodule)
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+    def _normalize_module_type(self, obj: Any) -> Optional[torch.nn.Module]:
+        # The group-offloading code may pass either:
+        #   - a single `torch.nn.Module`, or
+        #   - a container (list/tuple) of modules.
+        # Only return a module when the mapping is unambiguous:
+        #   - if `obj` is a module -> return it
+        #   - if `obj` is a list/tuple containing exactly one module -> return that module
+        #   - otherwise -> return None (not considered a pin-target candidate)
+        if isinstance(obj, torch.nn.Module):
+            return obj
+        if isinstance(obj, (list, tuple)):
+            mods = [m for m in obj if isinstance(m, torch.nn.Module)]
+            return mods[0] if len(mods) == 1 else None
+        return None
+
+
+class DummyCallableByNameSubmodule(DummyCallableBySubmodule):
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Same behaviour as DummyCallableBySubmodule, only with a different call signature:
+    invoked as `callable(name, submodule)`.
+    """
+
+    def __call__(self, name: str, submodule: torch.nn.Module) -> bool:
+        self.calls_track.append((name, submodule))
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+
+class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
+    """
+    Callable group-offloading pinner that pins the first and last DummyBlock.
+    Same behaviour as DummyCallableBySubmodule, only with a different call signature:
+    invoked as `callable(name, submodule, idx)`.
+    """
+
+    def __call__(self, name: str, submodule: torch.nn.Module, idx: int) -> bool:
+        self.calls_track.append((name, submodule, idx))
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+
+class DummyInvalidCallable(DummyCallableBySubmodule):
+    """
+    Callable group-offloading pinner that uses an invalid call signature.
+    """
+
+    def __call__(self, name: str, submodule: torch.nn.Module, idx: int, extra: Any) -> bool:
+        self.calls_track.append((name, submodule, idx, extra))
+        return self._normalize_module_type(submodule) in self.pin_targets
 
 
 @require_torch_accelerator
@@ -566,3 +633,164 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
             "layers_per_block": 1,
         }
         return init_dict
+
+    def test_block_level_offloading_with_pin_groups_stay_on_device(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def assert_all_modules_on_expected_device(
+            modules: Sequence[torch.nn.Module], expected_device: Union[torch.device, str], header_error_msg: str = ""
+        ) -> None:
+            def first_param_device(module: torch.nn.Module) -> torch.device:
+                p = next(module.parameters(), None)
+                self.assertIsNotNone(p, f"No parameters found for module {module}")
+                return p.device
+
+            if isinstance(expected_device, torch.device):
+                expected_device = expected_device.type
+
+            bad = []
+            for i, m in enumerate(modules):
+                dev_type = first_param_device(m).type
+                if dev_type != expected_device:
+                    bad.append((i, m.__class__.__name__, dev_type))
+            self.assertTrue(
+                len(bad) == 0,
+                (header_error_msg + "\n" if header_error_msg else "")
+                + f"Expected all modules on {expected_device}, but found mismatches: {bad}",
+            )
+
+        def get_param_modules_from_execution_order(model: DummyModel) -> List[torch.nn.Module]:
+            model.eval()
+            root_registry = HookRegistry.check_if_exists_or_initialize(model)
+
+            lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading")
+            self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered")
+
+            # record the execution order with the first forward pass
+            with torch.no_grad():
+                model(self.input)
+
+            mods = [m for _, m in lazy_hook.execution_order]
+            param_modules = [m for m in mods if next(m.parameters(), None) is not None]
+            return param_modules
+
+        def assert_callables_offloading_tests(
+            param_modules: Sequence[torch.nn.Module],
+            callable: Any,
+            header_error_msg: str = "",
+        ) -> None:
+            pinned_modules = [m for m in param_modules if m in callable.pin_targets]
+            unpinned_modules = [m for m in param_modules if m not in callable.pin_targets]
+            self.assertTrue(
+                len(callable.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once"
+            )
+            assert_all_modules_on_expected_device(
+                pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device"
+            )
+            assert_all_modules_on_expected_device(
+                unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded"
+            )
+
+        default_parameters = {
+            "onload_device": torch_device,
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+        }
+        model_default_no_pin = self.get_model()
+        model_default_no_pin.enable_group_offload(**default_parameters)
+        param_modules = get_param_modules_from_execution_order(model_default_no_pin)
+        assert_all_modules_on_expected_device(
+            param_modules,
+            expected_device="cpu",
+            header_error_msg="default pin_groups: expected ALL modules offloaded to CPU",
+        )
+
+        model_pin_all = self.get_model()
+        model_pin_all.enable_group_offload(
+            **default_parameters,
+            pin_groups="all",
+        )
+        param_modules = get_param_modules_from_execution_order(model_pin_all)
+        assert_all_modules_on_expected_device(
+            param_modules,
+            expected_device=torch_device,
+            header_error_msg="pin_groups = all: expected ALL layers on accelerator device",
+        )
+
+        model_pin_first_last = self.get_model()
+        model_pin_first_last.enable_group_offload(
+            **default_parameters,
+            pin_groups="first_last",
+        )
+        param_modules = get_param_modules_from_execution_order(model_pin_first_last)
+        assert_all_modules_on_expected_device(
+            [param_modules[0], param_modules[-1]],
+            expected_device=torch_device,
+            header_error_msg="pin_groups = first_last: expected first and last layers on accelerator device",
+        )
+        assert_all_modules_on_expected_device(
+            param_modules[1:-1],
+            expected_device="cpu",
+            header_error_msg="pin_groups = first_last: expected ALL middle layers offloaded to CPU",
+        )
+
+        model = self.get_model()
+        callable_by_submodule = DummyCallableBySubmodule(pin_targets=[model.blocks[0], model.blocks[-1]])
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_submodule)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules, callable_by_submodule, header_error_msg="pin_groups with callable(submodule)"
+        )
+
+        model = self.get_model()
+        callable_by_name_submodule = DummyCallableByNameSubmodule(pin_targets=[model.blocks[0], model.blocks[-1]])
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules, callable_by_name_submodule, header_error_msg="pin_groups with callable(name, submodule)"
+        )
+
+        model = self.get_model()
+        callable_by_name_submodule_idx = DummyCallableByNameSubmoduleIdx(
+            pin_targets=[model.blocks[0], model.blocks[-1]]
+        )
+        model.enable_group_offload(**default_parameters, pin_groups=callable_by_name_submodule_idx)
+        param_modules = get_param_modules_from_execution_order(model)
+        assert_callables_offloading_tests(
+            param_modules,
+            callable_by_name_submodule_idx,
+            header_error_msg="pin_groups with callable(name, submodule, idx)",
+        )
+
+    def test_error_raised_if_pin_groups_received_invalid_value(self):
+        default_parameters = {
+            "onload_device": torch_device,
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+        }
+        model = self.get_model()
+        with self.assertRaisesRegex(ValueError, "`pin_groups` must be None, 'all', 'first_last', or a callable."):
+            model.enable_group_offload(
+                **default_parameters,
+                pin_groups="invalid value",
+            )
+
+    def test_error_raised_if_pin_groups_received_invalid_callables(self):
+        default_parameters = {
+            "onload_device": torch_device,
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+        }
+        model = self.get_model()
+        invalid_callable = DummyInvalidCallable(pin_targets=[model.blocks[0], model.blocks[-1]])
+        # enabling succeeds; the signature error only surfaces on the first forward
+        model.enable_group_offload(
+            **default_parameters,
+            pin_groups=invalid_callable,
+        )
+        with self.assertRaisesRegex(TypeError, r"missing\s+\d+\s+required\s+positional\s+argument(s)?:"):
+            with torch.no_grad():
+                model(self.input)
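
Note the asymmetry encoded by the two error tests: an invalid string for `pin_groups` fails fast with a ValueError inside `enable_group_offload`, whereas a callable with an unsupported signature is only rejected lazily, raising a TypeError on the first forward pass, once the group-offloading hook actually invokes it.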
