@@ -17,7 +17,9 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
+from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
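+    """Variant of DummyModelWithMultipleBlocks whose top-level LayerNorm
+    exercises the parameter-only module-group path targeted by the PR above."""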
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
 
@@ -113,6 +138,16 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class LayerOutputTrackerHook(ModelHook):
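+    """Records every output of the module it is attached to across forward passes."""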
+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
+    def post_forward(self, module, output):
+        self.outputs.append(output)
+        return output
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
@@ -258,6 +293,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
         if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
+
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,
             hidden_features=self.hidden_features,
@@ -274,3 +310,56 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_
 
         with context:
             model(self.input)
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
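+        # Run a reference model and an identically initialized group-offloaded
+        # copy on the same input several times, then check that every tracked
+        # Linear/LayerNorm output matches across invocations.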
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+            for name, module in model.named_modules():
+                registry = HookRegistry.check_if_exists_or_initialize(module)
+                hook = LayerOutputTrackerHook()
+                registry.register_hook(hook, "layer_output_tracker")
+
+        model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model = DummyModelWithLayerNorm(128, 256, 128, 2)
+
+        model.load_state_dict(model_ref.state_dict(), strict=True)
+
+        model_ref.to(torch_device)
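+        # use_stream=True prefetches parameters on a side stream; with
+        # num_blocks_per_group=1, every block becomes its own offload group.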
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        apply_layer_output_tracker_hook(model_ref)
+        apply_layer_output_tracker_hook(model)
+
+        x = torch.randn(2, 128).to(torch_device)
+
+        out_ref = model_ref(x)
+        out = model(x)
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+        num_repeats = 4
+        for _ in range(num_repeats):
+            out_ref = model_ref(x)
+            out = model(x)
+
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+        for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+            assert ref_name == name
+            if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
+                continue
+            ref_outputs = (
+                HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+            )
+            outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
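+            # Sum the max abs diff of each recorded output against the first
+            # reference output; a nonzero sum flags outputs that drift across
+            # repeated invocations (e.g. parameters not reloaded after offload).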
+            cumulated_absmax = 0.0
+            for i in range(len(outputs)):
+                diff = ref_outputs[0] - outputs[i]
+                absdiff = diff.abs()
+                absmax = absdiff.max().item()
+                cumulated_absmax += absmax
+            self.assertLess(
+                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+            )