
Commit 24f9273

committed: address review comments
1 parent f227e15 commit 24f9273

15 files changed (+17, -44 lines)

src/diffusers/hooks/group_offloading.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def _apply_group_offloading_block_level(
         for i in range(0, len(submodule), num_blocks_per_group):
             current_modules = submodule[i : i + num_blocks_per_group]
             group = ModuleGroup(
-                modules=submodule[i : i + num_blocks_per_group],
+                modules=current_modules,
                 offload_device=offload_device,
                 onload_device=onload_device,
                 offload_leader=current_modules[-1],

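For reference, a minimal sketch of how block-level group offloading is typically applied to a model. The `apply_group_offloading` entry point lives in `diffusers.hooks`, but the checkpoint name below is a placeholder and the exact keyword names should be checked against the installed version:

import torch
from diffusers import UNet2DConditionModel
from diffusers.hooks import apply_group_offloading

# Hypothetical checkpoint; substitute whichever UNet/transformer checkpoint you use.
unet = UNet2DConditionModel.from_pretrained("some/checkpoint", subfolder="unet", torch_dtype=torch.float16)

# Consecutive blocks are grouped into ModuleGroups of `num_blocks_per_group` modules.
# Each group is onloaded to the GPU together and offloaded back to the CPU once its
# last member (the group's `offload_leader`, see the diff above) has finished its forward pass.
apply_group_offloading(
    unet,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",  # routed to _apply_group_offloading_block_level
    num_blocks_per_group=2,
)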
src/diffusers/models/autoencoders/autoencoder_oobleck.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = False
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(

src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     ```
     """
 
+    _supports_group_offloading = False
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/autoencoders/vq_model.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
     """
 
     _skip_layerwise_casting_patterns = ["quantize"]
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _no_split_modules = None
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
+    _supports_group_offloading = True
 
     def __init__(self):
         super().__init__()

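The attribute added above gives every ModelMixin subclass a default of supporting group offloading; the per-model diffs in this commit override it where hooks cannot be installed safely. A minimal sketch of the intended pattern (the guard function below is illustrative, not the exact check used in diffusers):

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin

class MyWeightNormedModel(ModelMixin, ConfigMixin):
    # Opt out, e.g. because submodules are wrapped with torch.nn.utils.weight_norm
    # and cannot be moved/hooked the way group offloading expects.
    _supports_group_offloading = False

def ensure_group_offloading_supported(model):
    # Illustrative guard: enabling code can check the class flag before installing hooks.
    if not getattr(model, "_supports_group_offloading", True):
        raise ValueError(f"{model.__class__.__name__} does not support group offloading.")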
src/diffusers/models/transformers/dit_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
 
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
     """
 
     _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 10 deletions
@@ -1020,25 +1020,19 @@ def _execution_device(self):
         [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
         Accelerate's module hooks.
         """
-        diffusers_hook_device = None
+        # When applying group offloading at the leaf_level, we're in the same situation as accelerate's sequential
+        # offloading. We need to return the onload device of the group offloading hooks so that the intermediates
+        # required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
                 continue
-
             for submodule in model.modules():
                 if not hasattr(submodule, "_diffusers_hook"):
                     continue
                 registry = submodule._diffusers_hook
                 hook = registry.get_hook("group_offloading")
                 if hook is not None:
-                    diffusers_hook_device = hook.group.onload_device
-                    break
-
-            if diffusers_hook_device is not None:
-                break
-
-        if diffusers_hook_device is not None:
-            return diffusers_hook_device
+                    return hook.group.onload_device
 
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:

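In practice, the simplified lookup above means that once any component carries a group offloading hook, the pipeline creates its intermediates on that hook's onload device. A hedged usage sketch (the model id is a placeholder, `pipe.transformer` depends on the pipeline, and the keyword names are assumptions to verify against the installed diffusers version):

import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

pipe = DiffusionPipeline.from_pretrained("some/pipeline-checkpoint", torch_dtype=torch.bfloat16)

# Leaf-level group offloading keeps weights on the CPU and streams each leaf module
# to the GPU only while it runs, similar to accelerate's sequential offloading.
apply_group_offloading(
    pipe.transformer,  # or pipe.unet, depending on the pipeline
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# _execution_device now resolves to the hook's onload device, so latents and prompt
# embeddings are created on "cuda" even though the weights live on the CPU.
print(pipe._execution_device)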
tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 0 additions & 9 deletions
@@ -132,15 +132,6 @@ def test_layerwise_casting_inference(self):
     def test_layerwise_casting_memory(self):
         pass
 
-    @unittest.skip(
-        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
-        "cast the module weights to the expected device (as required by forward pass). As a result, forward pass errors out. To fix:\n"
-        "1. Make sure `nn::Module::to(device)` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
-        "2. Unskip this test."
-    )
-    def test_group_offloading(self):
-        pass
-
 
 @slow
 class AutoencoderOobleckIntegrationTests(unittest.TestCase):

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Lines changed: 0 additions & 4 deletions
@@ -155,10 +155,6 @@ def test_enable_disable_slicing(self):
             "Without slicing outputs should match with the outputs when slicing is manually disabled.",
         )
 
-    @unittest.skip("Not quite sure why this test fails and unable to debug.")
-    def test_group_offloading(self):
-        pass
-
 
 @slow
 class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
