Commit ac74eed

add test

1 parent 8c6edb3 commit ac74eed

File tree: 1 file changed (+89, -0 lines)

tests/hooks/test_group_offloading.py

Lines changed: 89 additions & 0 deletions
@@ -17,7 +17,9 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
+from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
 
@@ -113,6 +138,16 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class LayerOutputTrackerHook(ModelHook):
+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
+    def post_forward(self, module, output):
+        self.outputs.append(output)
+        return output
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
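The LayerOutputTrackerHook added here records every output a module produces across forward passes. As a quick illustration of how it plugs into the registry, here is a minimal sketch using only the HookRegistry calls that appear in this diff; `module` stands in for any torch.nn.Module:

    # Attach the tracker (the registry lookup is idempotent).
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(LayerOutputTrackerHook(), "layer_output_tracker")

    # ... run one or more forward passes through the module ...

    # Read back every recorded output.
    outputs = registry.get_hook("layer_output_tracker").outputs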
@@ -258,6 +293,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
+
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,
             hidden_features=self.hidden_features,
@@ -274,3 +310,56 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
 
         with context:
             model(self.input)
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+            for name, module in model.named_modules():
+                registry = HookRegistry.check_if_exists_or_initialize(module)
+                hook = LayerOutputTrackerHook()
+                registry.register_hook(hook, "layer_output_tracker")
+
+        model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model = DummyModelWithLayerNorm(128, 256, 128, 2)
+
+        model.load_state_dict(model_ref.state_dict(), strict=True)
+
+        model_ref.to(torch_device)
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        apply_layer_output_tracker_hook(model_ref)
+        apply_layer_output_tracker_hook(model)
+
+        x = torch.randn(2, 128).to(torch_device)
+
+        out_ref = model_ref(x)
+        out = model(x)
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+        num_repeats = 4
+        for i in range(num_repeats):
+            out_ref = model_ref(x)
+            out = model(x)
+
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+        for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+            assert ref_name == name
+            if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
+                continue
+            ref_outputs = (
+                HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+            )
+            outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
+            cumulated_absmax = 0.0
+            for i in range(len(outputs)):
+                diff = ref_outputs[0] - outputs[i]
+                absdiff = diff.abs()
+                absmax = absdiff.max().item()
+                cumulated_absmax += absmax
+            self.assertLess(
+                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+            )
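The model's layer_norm is the interesting module in this test: with elementwise_affine=True it owns parameters (weight and bias) but has no child modules, which is presumably what produces the parameter-only module group named in the test and targeted by PR 12077. To reproduce the core parity check outside the test harness, a minimal sketch follows; it assumes an accelerator is available and a diffusers build exposing ModelMixin.enable_group_offload with the arguments used above. ToyModel is a hypothetical stand-in for DummyModelWithLayerNorm:

    import torch
    from diffusers.models import ModelMixin


    class ToyModel(ModelMixin):
        def __init__(self):
            super().__init__()
            self.linear_1 = torch.nn.Linear(128, 256)
            # elementwise_affine=True: parameters without child modules,
            # i.e. the parameter-only group the test exercises.
            self.layer_norm = torch.nn.LayerNorm(256, elementwise_affine=True)
            self.linear_2 = torch.nn.Linear(256, 128)

        def forward(self, x):
            return self.linear_2(self.layer_norm(self.linear_1(x)))


    device = "cuda"  # assumption: a CUDA (or XPU) device is present

    model_ref = ToyModel()
    model = ToyModel()
    model.load_state_dict(model_ref.state_dict(), strict=True)

    model_ref.to(device)
    # Same call as in the test: one block per group, with stream prefetching.
    model.enable_group_offload(device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

    x = torch.randn(2, 128, device=device)
    for _ in range(4):  # repeated invocations, as the test does
        assert torch.allclose(model_ref(x), model(x), atol=1e-5)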
