Skip to content

Commit db2fd3b

Browse files
committed
add model tests
1 parent 5ea3d8a commit db2fd3b

File tree

8 files changed

+89
-6
lines changed

8 files changed

+89
-6
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,8 @@ def _apply_group_offloading_block_level(
387387
cpu_param_dict=None,
388388
onload_self=True,
389389
)
390-
_apply_group_offloading_hook(module, unmatched_group, force_offload, matched_module_groups[0])
390+
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
391+
_apply_group_offloading_hook(module, unmatched_group, force_offload, next_group)
391392

392393

393394
def _apply_group_offloading_leaf_level(
@@ -522,9 +523,13 @@ def _apply_group_offloading_hook(
522523
offload_on_init: bool,
523524
next_group: Optional[ModuleGroup] = None,
524525
) -> None:
525-
hook = GroupOffloadingHook(group, offload_on_init, next_group)
526526
registry = HookRegistry.check_if_exists_or_initialize(module)
527-
registry.register_hook(hook, _GROUP_OFFLOADING)
527+
528+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
529+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
530+
if registry.get_hook(_GROUP_OFFLOADING) is None:
531+
hook = GroupOffloadingHook(group, offload_on_init, next_group)
532+
registry.register_hook(hook, _GROUP_OFFLOADING)
528533

529534

530535
def _apply_lazy_group_offloading_hook(
@@ -533,13 +538,15 @@ def _apply_lazy_group_offloading_hook(
533538
offload_on_init: bool,
534539
next_group: Optional[ModuleGroup] = None,
535540
) -> None:
536-
hook = GroupOffloadingHook(group, offload_on_init, next_group)
537-
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
538541
registry = HookRegistry.check_if_exists_or_initialize(module)
542+
539543
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
540544
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
541545
if registry.get_hook(_GROUP_OFFLOADING) is None:
546+
hook = GroupOffloadingHook(group, offload_on_init, next_group)
542547
registry.register_hook(hook, _GROUP_OFFLOADING)
548+
549+
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
543550
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
544551

545552

src/diffusers/hooks/hooks.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def __init__(self, module_ref: torch.nn.Module) -> None:
120120

121121
def register_hook(self, hook: ModelHook, name: str) -> None:
122122
if name in self.hooks.keys():
123-
logger.warning(f"Hook with name {name} already exists, replacing it.")
123+
raise ValueError(
124+
f"Hook with name {name} already exists in the registry. Please use a different name or "
125+
f"first remove the existing hook and then add a new one."
126+
)
124127

125128
self._module_ref = hook.initialize_hook(self._module_ref)
126129

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ def test_layerwise_casting_inference(self):
132132
def test_layerwise_casting_memory(self):
133133
pass
134134

135+
@unittest.skip(
136+
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
137+
"cast the module weights to the expected device (as required by forward pass). As a result, forward pass errors out. To fix:\n"
138+
"1. Make sure `nn::Module::to(device)` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
139+
"2. Unskip this test."
140+
)
141+
def test_group_offloading(self):
142+
pass
143+
135144

136145
@slow
137146
class AutoencoderOobleckIntegrationTests(unittest.TestCase):

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ def test_enable_disable_slicing(self):
155155
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
156156
)
157157

158+
@unittest.skip("Not quite sure why this test fails and unable to debug.")
159+
def test_group_offloading(self):
160+
pass
161+
158162

159163
@slow
160164
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):

tests/models/autoencoders/test_models_vq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,7 @@ def test_loss_pretrained(self):
116116
expected_output = torch.tensor([0.1936])
117117
# fmt: on
118118
self.assertTrue(torch.allclose(output, expected_output, atol=1e-3))
119+
120+
@unittest.skip("Group offloading for torch::nn::Embedding layers is not yet supported.")
121+
def test_group_offloading(self):
122+
pass

tests/models/test_modeling_common.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from parameterized import parameterized
3838
from requests.exceptions import HTTPError
3939

40+
from diffusers.hooks import apply_group_offloading
4041
from diffusers.models import UNet2DConditionModel
4142
from diffusers.models.attention_processor import (
4243
AttnProcessor,
@@ -1433,6 +1434,45 @@ def get_memory_usage(storage_dtype, compute_dtype):
14331434
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
14341435
)
14351436

1437+
@require_torch_gpu
1438+
def test_group_offloading(self):
1439+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1440+
torch.manual_seed(0)
1441+
1442+
def run_forward(model):
1443+
model.eval()
1444+
with torch.no_grad():
1445+
return model(**inputs_dict)[0]
1446+
1447+
model = self.model_class(**init_dict)
1448+
model.to(torch_device)
1449+
output_without_group_offloading = run_forward(model)
1450+
1451+
torch.manual_seed(0)
1452+
model = self.model_class(**init_dict)
1453+
apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1)
1454+
output_with_group_offloading1 = run_forward(model)
1455+
1456+
torch.manual_seed(0)
1457+
model = self.model_class(**init_dict)
1458+
apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
1459+
output_with_group_offloading2 = run_forward(model)
1460+
1461+
torch.manual_seed(0)
1462+
model = self.model_class(**init_dict)
1463+
apply_group_offloading(model, offload_type="leaf_level")
1464+
output_with_group_offloading3 = run_forward(model)
1465+
1466+
torch.manual_seed(0)
1467+
model = self.model_class(**init_dict)
1468+
apply_group_offloading(model, offload_type="leaf_level", use_stream=True)
1469+
output_with_group_offloading4 = run_forward(model)
1470+
1471+
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
1472+
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
1473+
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
1474+
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
1475+
14361476

14371477
@is_staging_test
14381478
class ModelPushToHubTester(unittest.TestCase):

tests/models/transformers/test_models_dit_transformer2d.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,11 @@ def test_correct_class_remapping_from_pretrained_config(self):
100100
def test_correct_class_remapping(self):
101101
model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
102102
assert isinstance(model, DiTTransformer2DModel)
103+
104+
@unittest.skip(
105+
"This model uses a direct call to self.transformer_blocks[0].norm1.emb. This causes attached hooks to not be invoked "
106+
"when block offloading is enabled. In order for it to work, the model should correctly first invoke the forward pass "
107+
"the transformer blocks, so that weights can be onloaded, instead of directly invoking a submodule of the block."
108+
)
109+
def test_group_offloading(self):
110+
pass

tests/models/transformers/test_models_transformer_hunyuan_dit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,11 @@ def test_set_xformers_attn_processor_for_determinism(self):
111111
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
112112
def test_set_attn_processor_for_determinism(self):
113113
pass
114+
115+
@unittest.skip(
116+
"This model uses a direct call to F.multi_head_attention_forward instead of using a torch.nn.Module layer. This "
117+
"usage is not yet supported with group offloading, because the call directly operates on the weights of the module. "
118+
"We attach hooks correctly, but the onloading does not occur because the torch::nn::Module::forward is never invoked."
119+
)
120+
def test_group_offloading(self):
121+
pass

0 commit comments

Comments
 (0)