
Commit f848feb

feat: allow sharding for auraflow. (#8853)
1 parent b382550 commit f848feb

File tree

2 files changed, 3 insertions(+), 0 deletions(-)

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
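The added `_no_split_modules` attribute tells Accelerate's big-model loading which block classes must be kept whole on a single device when the model is sharded with `device_map="auto"`. A minimal sketch of what this enables; the checkpoint id "fal/AuraFlow" and the `subfolder` value are assumptions for illustration:

```python
# Minimal sketch: loading the AuraFlow transformer with automatic sharding.
# device_map="auto" asks Accelerate to place submodules across the available
# devices; classes listed in _no_split_modules are never split across devices.
# The checkpoint id "fal/AuraFlow" and subfolder "transformer" are assumptions.
import torch
from diffusers import AuraFlowTransformer2DModel

transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow",
    subfolder="transformer",
    torch_dtype=torch.float16,
    device_map="auto",  # shard across GPUs/CPU without breaking the listed blocks
)
```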

tests/models/transformers/test_models_transformer_aura_flow.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
 class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = AuraFlowTransformer2DModel
     main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]
 
     @property
     def dummy_input(self):
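The `model_split_percents` override raises the memory fractions used by the shared sharding tests, since the dummy AuraFlow transformer built for testing is very small. Each value is a fraction of the total model size that becomes a `max_memory` cap when the test reloads the model with `device_map="auto"`. A rough sketch of that pattern, with the helper name and exact budgeting assumed rather than taken verbatim from `ModelTesterMixin`:

```python
# Rough sketch of how a split percentage turns into a max_memory budget for
# sharded loading (assumed to mirror the ModelTesterMixin parallelism tests;
# reload_with_split is an illustrative helper, not part of diffusers).
from accelerate.utils import compute_module_sizes

def reload_with_split(model_class, checkpoint_dir, split_percent):
    # Measure the total size (in bytes) of the model to be sharded.
    reference = model_class.from_pretrained(checkpoint_dir)
    total_size = compute_module_sizes(reference)[""]

    # Cap each GPU at a fraction of the total size so the model must be split.
    per_device = int(split_percent * total_size)
    max_memory = {0: per_device, 1: per_device, "cpu": total_size * 2}

    # Accelerate builds the device map, keeping _no_split_modules together.
    return model_class.from_pretrained(
        checkpoint_dir, device_map="auto", max_memory=max_memory
    )
```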
