Skip to content

Commit be76b82

Browse files
committed
added invalid test cases
1 parent 536d8e6 commit be76b82

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

tests/hooks/test_group_offloading.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule):
204204
def __call__(self, name, submodule, idx):
205205
self.calls_track.append((name, submodule, idx))
206206
return self._normalize_module_type(submodule) in self.pin_targets
207+
208+
# Test for https://github.com/huggingface/diffusers/pull/12747
209+
class DummyInvalidCallable(DummyCallableBySubmodule):
210+
"""
211+
Callable group offloading pinner that uses invalid call signature
212+
"""
213+
def __call__(self, name, submodule, idx, extra):
214+
self.calls_track.append((name, submodule, idx, extra))
215+
return self._normalize_module_type(submodule) in self.pin_targets
207216

208217

209218
@require_torch_accelerator
@@ -420,7 +429,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
420429
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
421430
)
422431

423-
def test_block_level_pin_groups_stay_on_device(self):
432+
def test_block_level_offloading_with_pin_groups_stay_on_device(self):
424433
if torch.device(torch_device).type not in ["cuda", "xpu"]:
425434
return
426435

@@ -538,3 +547,39 @@ def assert_callables_offloading_tests(param_modules,
538547
callable_by_name_submodule_idx,
539548
header_error_msg="pin_groups with callable(name, submodule, idx)")
540549

550+
def test_error_raised_if_pin_groups_received_invalid_value(self):
551+
default_parameters = {
552+
"onload_device": torch_device,
553+
"offload_type": "block_level",
554+
"num_blocks_per_group": 1,
555+
"use_stream": True,
556+
}
557+
model = self.get_model()
558+
with self.assertRaisesRegex(ValueError,
559+
"`pin_groups` must be one of `None`, 'first_last', 'all', or a callable."):
560+
model.enable_group_offload(
561+
**default_parameters,
562+
pin_groups="invalid value",
563+
)
564+
565+
def test_error_raised_if_pin_groups_received_invalid_callables(self):
566+
default_parameters = {
567+
"onload_device": torch_device,
568+
"offload_type": "block_level",
569+
"num_blocks_per_group": 1,
570+
"use_stream": True,
571+
}
572+
model = self.get_model()
573+
invalid_callable = DummyInvalidCallable(pin_targets=[model.blocks[0], model.blocks[-1]])
574+
model.enable_group_offload(
575+
**default_parameters,
576+
pin_groups=invalid_callable,
577+
)
578+
with self.assertRaisesRegex(TypeError,
579+
r"missing\s+\d+\s+required\s+positional\s+argument(s)?:"):
580+
with torch.no_grad():
581+
model(self.input)
582+
583+
584+
585+

0 commit comments

Comments
 (0)