@sayakpaul
Member

What does this PR do?

@a-r-r-o-w the following two model splitting tests are failing:

FAILED tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py::HunyuanVideoTransformer3DTests::test_model_parallelism - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argumen...
FAILED tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py::HunyuanVideoTransformer3DTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument

Could you take a look when you have time? There are similar failures in the HunyuanVideo transformer model too, just as an FYI. Also cc: @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Contributor

For this to pass, clean_x_embedder and x_embedder need to be placed on the same device. This is because accelerate assigns devices to layers based on their initialization order. Moving the clean_x_embedder initialization right below x_embedder, and the image_projection initialization right below context_embedder, fixes the errors in these tests.

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index 0331d9934..012a6e532 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -152,9 +152,14 @@ class HunyuanVideoFramepackTransformer3DModel(
 
         # 1. Latent and condition embedders
         self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+        self.clean_x_embedder = None
+        if has_clean_x_embedder:
+            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
+        # Framepack specific modules
+        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
         self.time_text_embed = HunyuanVideoConditionEmbedding(
             inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
         )
@@ -186,13 +191,6 @@ class HunyuanVideoFramepackTransformer3DModel(
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
 
-        # Framepack specific modules
-        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
-
-        self.clean_x_embedder = None
-        if has_clean_x_embedder:
-            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
-
         self.gradient_checkpointing = False
 
     def forward(
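To illustrate why initialization order matters, here is a heavily simplified toy allocator. It assumes modules are placed greedily in registration order and that each device has a fixed budget; the sizes and budgets are made up, and accelerate's real device-map inference (`infer_auto_device_map`) also accounts for things like `_no_split_modules`, so treat this as a sketch of the idea rather than its implementation.

```python
def sequential_device_map(module_sizes, max_memory):
    """Toy allocator (not accelerate's real algorithm): walk top-level modules in
    registration order and fill each device up to its budget before moving on."""
    devices = list(max_memory)  # device order, e.g. ["cuda:0", "cuda:1"]
    device_map, dev_idx, used = {}, 0, 0
    for name, size in module_sizes.items():
        # Spill over to the next device once this module no longer fits.
        while used + size > max_memory[devices[dev_idx]]:
            dev_idx, used = dev_idx + 1, 0
            if dev_idx == len(devices):
                raise RuntimeError(f"Not enough memory left for {name!r}")
        device_map[name] = devices[dev_idx]
        used += size
    return device_map


# Made-up sizes; dict order mirrors the order modules are assigned in __init__.
budget = {"cuda:0": 6, "cuda:1": 8}
original_order = {
    "x_embedder": 1, "context_embedder": 2, "time_text_embed": 1,
    "transformer_blocks": 4, "norm_out": 1, "proj_out": 1,
    "image_projection": 1, "clean_x_embedder": 1,
}
reordered = {
    "x_embedder": 1, "clean_x_embedder": 1, "context_embedder": 2,
    "image_projection": 1, "time_text_embed": 1, "transformer_blocks": 4,
    "norm_out": 1, "proj_out": 1,
}

before = sequential_device_map(original_order, budget)
after = sequential_device_map(reordered, budget)
print(before["x_embedder"], before["clean_x_embedder"])  # cuda:0 cuda:1
print(after["x_embedder"], after["clean_x_embedder"])    # cuda:0 cuda:0
```

With the original order, the trailing Framepack modules spill onto the second device, away from the embedders they are concatenated with; declaring them next to their counterparts keeps each pair together.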

But, this is not a "correct" fix in the general case. We need to put in device handling code in the concatenate statements for it to work as expected in the correct way. Something like:

hidden_states = torch.cat([latents_clean.to(hidden_states), hidden_states], dim=1)

IMO it makes the code look unnecessarily complicated, since these tensors are already expected to be on the correct device/dtype in the single-GPU case. If we'd like to make these changes anyway, LMK and I'll open a PR.
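A toy illustration of the `.to(hidden_states)` approach, using a hypothetical `ToyTensor` class that only tracks device and dtype instead of real PyTorch tensors. It mimics two behaviours: `torch.cat` rejecting inputs that live on different devices, and `Tensor.to(other)` returning a copy with `other`'s device and dtype.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor: a flat list plus device/dtype tags."""

    def __init__(self, data, device="cuda:0", dtype="float32"):
        self.data = list(data)
        self.device = device
        self.dtype = dtype

    def to(self, other):
        # Like torch.Tensor.to(other): return a copy with other's device and dtype.
        return ToyTensor(self.data, device=other.device, dtype=other.dtype)


def cat(tensors):
    # Like torch.cat, refuse to concatenate tensors that live on different devices.
    devices = {t.device for t in tensors}
    if len(devices) > 1:
        raise RuntimeError(
            f"Expected all tensors to be on the same device, but found {sorted(devices)}"
        )
    first = tensors[0]
    return ToyTensor([x for t in tensors for x in t.data], first.device, first.dtype)


hidden_states = ToyTensor([1.0, 2.0], device="cuda:0")
latents_clean = ToyTensor([3.0], device="cuda:1")  # e.g. produced by a layer on another GPU

try:
    cat([latents_clean, hidden_states])
    error_message = None
except RuntimeError as err:
    error_message = str(err)

fixed = cat([latents_clean.to(hidden_states), hidden_states])
print(fixed.device, fixed.data)  # cuda:0 [3.0, 1.0, 2.0]
```

The fix works, but every concatenation site needs the extra `.to(...)`, which is the readability cost being weighed here.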

@sayakpaul
Member Author

But, this is not a "correct" fix in the general case. We need to put in device handling code in the concatenate statements for it to work as expected in the correct way. Something like:

That's exactly why I didn't make these changes; I strongly echo your opinion on it.

So, given that, I think it's still preferable to go with the other option you mentioned, i.e., changing the initialization order.

@a-r-r-o-w
Contributor

a-r-r-o-w commented May 10, 2025

#11535 should hopefully fix the error you're seeing for Framepack. For Hunyuan Video, the following patch is required:

--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -112,11 +112,12 @@ class HunyuanVideoAttnProcessor2_0:
             if attn.norm_added_k is not None:
                 encoder_key = attn.norm_added_k(encoder_key)
 
-            query = torch.cat([query, encoder_query], dim=2)
-            key = torch.cat([key, encoder_key], dim=2)
-            value = torch.cat([value, encoder_value], dim=2)
+            query = torch.cat([query, encoder_query.to(query)], dim=2)
+            key = torch.cat([key, encoder_key.to(key)], dim=2)
+            value = torch.cat([value, encoder_value.to(value)], dim=2)
 
         # 5. Attention
+        key, value, attention_mask = (x.to(query.device) for x in (key, value, attention_mask))
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -865,8 +866,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
     _no_split_modules = [
+        "HunyuanVideoConditionEmbedding",
         "HunyuanVideoTransformerBlock",
         "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoTokenReplaceTransformerBlock",
+        "HunyuanVideoTokenReplaceSingleTransformerBlock",
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]
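Roughly, listing a class in `_no_split_modules` tells the device-map inference to keep every instance of that module on a single device. The sketch below models that with plain dicts; the child module names and sizes under `HunyuanVideoConditionEmbedding` are hypothetical, chosen only to show how adding the entry changes what can be placed independently.

```python
def subtree_size(module):
    children = module.get("children", {})
    if not children:
        return module["size"]
    return sum(subtree_size(child) for child in children.values())


def allocation_units(name, module, no_split_classes):
    """Flatten a module tree into the chunks an allocator may place independently.
    Modules whose class is listed in no_split_classes are kept whole."""
    children = module.get("children", {})
    if not children or module["cls"] in no_split_classes:
        return [(name, subtree_size(module))]
    units = []
    for child_name, child in children.items():
        units.extend(allocation_units(f"{name}.{child_name}", child, no_split_classes))
    return units


# Hypothetical child names and sizes, for illustration only.
time_text_embed = {
    "cls": "HunyuanVideoConditionEmbedding",
    "children": {
        "timestep_embedder": {"cls": "Leaf", "size": 2},
        "text_embedder": {"cls": "Leaf", "size": 3},
        "guidance_embedder": {"cls": "Leaf", "size": 2},
    },
}

without_entry = allocation_units(
    "time_text_embed", time_text_embed, {"HunyuanVideoTransformerBlock"}
)
with_entry = allocation_units(
    "time_text_embed", time_text_embed,
    {"HunyuanVideoTransformerBlock", "HunyuanVideoConditionEmbedding"},
)
print(without_entry)  # three separately placeable children
print(with_entry)     # [('time_text_embed', 7)]
```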

I think the no-split-modules changes are okay, but the changes to the attention processor complicate otherwise easily readable code (same reasoning as for not using the "correct" fix mentioned above). I don't think the model parallelism implementation was really meant to handle complex cases like this, similar to how group offloading does not really work as expected with MoE implementation. Probably better to skip the test making note of why it would fail, but up to you.

Edit: The model-parallel implementation could handle this if attention processors were nn.Modules, but since they are just plain wrapper classes, they don't get the necessary device-modifying hooks registered.
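The point in the Edit note can be modelled with a toy: a module wrapped by a pre-forward hook gets its tensor inputs moved to its execution device, while a plain callable (like an attention processor) does not. `ToyTensor`, `HookedModule`, and `PlainProcessor` are hypothetical stand-ins, not accelerate or diffusers classes.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor that only records its device."""

    def __init__(self, device):
        self.device = device

    def to(self, device):
        return ToyTensor(device)


def attention(query, key, value, attention_mask=None):
    # Plays the role of the attention kernel: all inputs must share one device.
    if len({t.device for t in (query, key, value)}) > 1:
        raise RuntimeError("Expected all tensors to be on the same device")
    return ToyTensor(query.device)


class HookedModule:
    """Toy nn.Module with a device-alignment pre-forward hook."""

    def __init__(self, forward, execution_device):
        self.forward = forward
        self.execution_device = execution_device

    def __call__(self, *args):
        # Move every tensor argument to this module's device; None passes through.
        args = [a.to(self.execution_device) if isinstance(a, ToyTensor) else a for a in args]
        return self.forward(*args)


class PlainProcessor:
    """Toy attention processor: a plain callable, so no hook ever wraps it."""

    def __call__(self, query, key, value, attention_mask=None):
        return attention(query, key, value, attention_mask)


query, key, value = ToyTensor("cuda:1"), ToyTensor("cuda:0"), ToyTensor("cuda:0")

try:
    PlainProcessor()(query, key, value, None)
    plain_error = None
except RuntimeError as err:
    plain_error = str(err)

hooked_out = HookedModule(attention, execution_device="cuda:1")(query, key, value, None)
print(plain_error)        # Expected all tensors to be on the same device
print(hooked_out.device)  # cuda:1
```

Note that the toy hook lets `None` through untouched. The generator expression in the patch above calls `.to()` on `attention_mask` unconditionally, which would fail if the mask were ever `None` at that point.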

@sayakpaul
Member Author

I don't think the model parallelism implementation was really meant to handle complex cases like this, similar to how group offloading does not really work as expected with MoE implementation. Probably better to skip the test making note of why it would fail, but up to you.

Thanks! Would it be possible to skip them accordingly and batch that in a separate PR?

I have confirmed that your fixes in #11535 solve my initial issue. So, please have a look at this PR and LMK.

Contributor

@a-r-r-o-w a-r-r-o-w left a comment

Sounds good, will skip in a separate PR
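A minimal sketch of skipping the failing tests with a note explaining why, using the standard `unittest.skip` decorator. The class below is a hypothetical stand-in for the real HunyuanVideo test class (which also mixes in diffusers' shared model tests), and it assumes the HunyuanVideo failures occur under the same test names as the Framepack ones.

```python
import unittest

SKIP_REASON = (
    "Attention processors are plain callables rather than nn.Modules, so no "
    "device-alignment hooks are registered on them and their inputs are not "
    "moved to a common device under model parallelism."
)


class HunyuanVideoTransformer3DTests(unittest.TestCase):
    """Hypothetical stand-in; the real class also mixes in the shared model tests."""

    @unittest.skip(SKIP_REASON)
    def test_model_parallelism(self):
        pass

    @unittest.skip(SKIP_REASON)
    def test_sharded_checkpoints_device_map(self):
        pass
```

Keeping the reason in a single constant means the explanation shows up verbatim in the test report for both skipped tests.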

@sayakpaul sayakpaul merged commit 01abfc8 into main May 11, 2025
16 checks passed
@sayakpaul sayakpaul deleted the framepack-transformer-tests branch May 11, 2025 04:20

4 participants