Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ def _get_signature_keys(cls, obj):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]

@register_to_config
def __init__(
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_allegro.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""

_supports_gradient_checkpointing = True

@register_to_config
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

@register_to_config
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""

_supports_gradient_checkpointing = True
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]

@register_to_config
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
Expand Down Expand Up @@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
Expand Down
Loading