
Conversation

@a-r-r-o-w (Contributor) commented on Apr 29, 2025

Fixes #11431 (review)

Will run the full model for testing in some time.

Testing code:
import torch
from diffusers import HunyuanVideoTransformer3DModel

@torch.no_grad()
def main():
    device = "cuda"
    dtype = torch.bfloat16
    batch_size = 1
    num_channels = 4
    num_frames = 2
    height = 4
    width = 4
    sequence_length = 8
    
    # Tiny model config so the forward pass runs quickly as a smoke test
    transformer = HunyuanVideoTransformer3DModel(
        in_channels=4,
        out_channels=4,
        num_attention_heads=2,
        attention_head_dim=10,
        num_layers=2,
        num_single_layers=2,
        num_refiner_layers=1,
        patch_size=1,
        patch_size_t=1,
        guidance_embeds=True,
        text_embed_dim=16,
        pooled_projection_dim=8,
        rope_axes_dim=(2, 4, 4),
    ).to(device=device, dtype=dtype)

    hidden_states = torch.randn(batch_size, num_channels, num_frames, height, width, dtype=dtype, device=device)
    timestep = torch.randint(0, 1000, (batch_size,), dtype=torch.long, device=device)
    encoder_hidden_states = torch.randn(batch_size, sequence_length, 16, dtype=dtype, device=device)
    # Boolean text mask; the last two tokens are treated as padding
    encoder_attention_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    encoder_attention_mask[:, :sequence_length - 2] = True
    pooled_projections = torch.randn(batch_size, 8, dtype=dtype, device=device)
    guidance = torch.randint(0, 1000, (batch_size,), dtype=torch.long, device=device)

    # fullgraph=True surfaces any graph breaks introduced by the mask construction
    transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
    output = transformer(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        pooled_projections=pooled_projections,
        guidance=guidance,
        return_dict=False,
    )[0]
    print(output.shape)


if __name__ == "__main__":
    main()

@a-r-r-o-w a-r-r-o-w requested review from Copilot and sayakpaul April 29, 2025 22:24

Copilot AI left a comment


Pull Request Overview

This PR improves fullgraph compatibility for Hunyuan Video by replacing a loop‑based attention mask construction with a vectorized approach.

  • Replaces manual per‐batch loop with a vectorized masked_fill operation.
  • Updates the attention mask initialization from zeros to ones and adds appropriate unsqueezing for broadcasting.
Comments suppressed due to low confidence (1)

src/diffusers/models/transformers/transformer_hunyuan_video.py:1071

  • [nitpick] The refactored attention mask construction is more efficient; for clarity, consider adding an inline comment that explains the logic behind initializing with ones and then using masked_fill.
attention_mask = torch.ones(batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool)
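
For reference, a minimal sketch of the pattern this comment describes (illustrative only, not the exact PR diff; effective_sequence_length here is a stand-in for the latent length plus the summed text-mask length): the per-batch Python loop is replaced with a broadcasted comparison and masked_fill, which torch.compile can trace without graph breaks.

import torch

batch_size, sequence_length = 2, 10
effective_sequence_length = torch.tensor([7, 9])  # valid (unpadded) length per batch element

# Loop-based construction: a data-dependent Python loop over the batch
mask_loop = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
for i in range(batch_size):
    mask_loop[i, : effective_sequence_length[i]] = True

# Vectorized construction: start from ones and mask out the padded tail
indices = torch.arange(sequence_length).unsqueeze(0)  # shape (1, sequence_length)
mask_vec = torch.ones(batch_size, sequence_length, dtype=torch.bool)
mask_vec = mask_vec.masked_fill(indices >= effective_sequence_length.unsqueeze(1), False)

assert torch.equal(mask_loop, mask_vec)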

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment


Looks solid, thanks for working on this!

To confirm, I first checked out the PR branch of #11431, merged your PR branch into it, and then ran RUN_SLOW=1 RUN_COMPILE=1 pytest tests/models/transformers/test_models_transformer_hunyuan_video.py -k "test_torch_compile_recompilation_and_graph_break". Everything was green.
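
For context, roughly what such a test checks (a sketch under the assumption that the real diffusers test follows this pattern; the helper name is hypothetical): the model is compiled with fullgraph=True so any graph break raises, and a second call under error_on_recompile verifies nothing triggers a recompilation.

import torch

def assert_compiles_cleanly(model, inputs):
    # fullgraph=True turns graph breaks into hard errors
    torch._dynamo.reset()
    compiled = torch.compile(model, fullgraph=True)
    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        compiled(**inputs)
        compiled(**inputs)  # the second call must reuse the already-compiled graph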

@sayakpaul sayakpaul added the performance and torch.compile labels Apr 30, 2025
@a-r-r-o-w a-r-r-o-w merged commit c865115 into main Apr 30, 2025
15 of 16 checks passed
@a-r-r-o-w a-r-r-o-w deleted the improve-hunyuan-compile-support branch April 30, 2025 05:51