Skip to content

Commit 71ad16b

Browse files
a-r-r-o-wDN6
andauthored
Add _no_split_modules to some models (huggingface#10308)
* set supports gradient checkpointing to true where necessary; add missing no split modules * fix cogvideox tests * update --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent ee7e141 commit 71ad16b

File tree

7 files changed

+14
-4
lines changed

7 files changed

+14
-4
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,7 @@ def _get_signature_keys(cls, obj):
12141214
# Adapted from `transformers` modeling_utils.py
12151215
def _get_no_split_modules(self, device_map: str):
12161216
"""
1217-
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
1217+
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
12181218
get the underlying `_no_split_modules`.
12191219
12201220
Args:

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
210210
"""
211211

212212
_supports_gradient_checkpointing = True
213+
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
213214

214215
@register_to_config
215216
def __init__(

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
221221
Scaling factor to apply in 3D positional embeddings across time dimension.
222222
"""
223223

224+
_supports_gradient_checkpointing = True
225+
224226
@register_to_config
225227
def __init__(
226228
self,

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
166166
"""
167167

168168
_supports_gradient_checkpointing = True
169+
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
169170

170171
@register_to_config
171172
def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
542542
"""
543543

544544
_supports_gradient_checkpointing = True
545+
_no_split_modules = [
546+
"HunyuanVideoTransformerBlock",
547+
"HunyuanVideoSingleTransformerBlock",
548+
"HunyuanVideoPatchEmbed",
549+
"HunyuanVideoTokenRefiner",
550+
]
545551

546552
@register_to_config
547553
def __init__(

tests/models/transformers/test_models_transformer_cogvideox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
7171
"out_channels": 4,
7272
"time_embed_dim": 2,
7373
"text_embed_dim": 8,
74-
"num_layers": 1,
74+
"num_layers": 2,
7575
"sample_width": 8,
7676
"sample_height": 8,
7777
"sample_frames": 8,
@@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self):
130130
"out_channels": 4,
131131
"time_embed_dim": 2,
132132
"text_embed_dim": 8,
133-
"num_layers": 1,
133+
"num_layers": 2,
134134
"sample_width": 8,
135135
"sample_height": 8,
136136
"sample_frames": 8,

tests/models/transformers/test_models_transformer_cogview3plus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
7171
init_dict = {
7272
"patch_size": 2,
7373
"in_channels": 4,
74-
"num_layers": 1,
74+
"num_layers": 2,
7575
"attention_head_dim": 4,
7676
"num_attention_heads": 2,
7777
"out_channels": 4,

0 commit comments

Comments
 (0)