Skip to content

Commit b1ded13

Browse files
committed
set supports gradient checkpointing to true where necessary; add missing no split modules
1 parent 074798b commit b1ded13

File tree

8 files changed

+15
-1
lines changed

8 files changed

+15
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,7 @@ def _get_signature_keys(cls, obj):
11961196
# Adapted from `transformers` modeling_utils.py
11971197
def _get_no_split_modules(self, device_map: str):
11981198
"""
1199-
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
1199+
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
12001200
get the underlying `_no_split_modules`.
12011201
12021202
Args:

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
210210
"""
211211

212212
_supports_gradient_checkpointing = True
213+
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
213214

214215
@register_to_config
215216
def __init__(

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
class LatteTransformer3DModel(ModelMixin, ConfigMixin):
2929
_supports_gradient_checkpointing = True
30+
_no_split_modules = ["BasicTransformerBlock"]
3031

3132
"""
3233
A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
221221
Scaling factor to apply in 3D positional embeddings across time dimension.
222222
"""
223223

224+
_supports_gradient_checkpointing = True
225+
_no_split_modules = ["AllegroTransformerBlock"]
226+
224227
@register_to_config
225228
def __init__(
226229
self,

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
166166
"""
167167

168168
_supports_gradient_checkpointing = True
169+
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
169170

170171
@register_to_config
171172
def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
540540
"""
541541

542542
_supports_gradient_checkpointing = True
543+
_no_split_modules = [
544+
"HunyuanVideoTransformerBlock",
545+
"HunyuanVideoSingleTransformerBlock",
546+
"HunyuanVideoPatchEmbed",
547+
"HunyuanVideoTokenRefiner",
548+
]
543549

544550
@register_to_config
545551
def __init__(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
295295
"""
296296

297297
_supports_gradient_checkpointing = True
298+
_no_split_modules = ["LTXTransformerBlock"]
298299

299300
@register_to_config
300301
def __init__(

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
125125
"""
126126

127127
_supports_gradient_checkpointing = True
128+
_no_split_modules = ["JointTransformerBlock", "SD3SingleTransformerBlock"]
128129

129130
@register_to_config
130131
def __init__(

0 commit comments

Comments
 (0)