
Conversation

@aviveise

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @aviveise, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors core components of the wan2.2s2v model within the diffsynth framework. The primary goal is to introduce a 'tea cache' for performance optimization and to improve code clarity through more descriptive variable naming. Additionally, the WanVideoPipeline has been enhanced to support an offline_preprocessing mode, providing greater flexibility for different data processing scenarios. These changes collectively aim to boost the model's efficiency and make the codebase easier to understand and extend.

Highlights

  • Refactoring of Variable Names: Generic tensor variable names like x and context have been systematically renamed to more descriptive latents/hidden_states and encoder_hidden_states across multiple model and distributed processing files, enhancing code readability and maintainability.
  • Tea Cache Integration: A new tea_cache mechanism has been introduced and integrated into the WanS2VModel and its associated pipeline functions. This cache aims to optimize computation by storing and reusing intermediate states, potentially improving performance (see the illustrative sketch after this list).
  • Offline Preprocessing Support: The WanVideoPipeline now supports an offline_preprocessing flag, which allows for dynamic configuration of pipeline units. This enables specialized workflows with new units like WanVideoUnit_InputVideoEmbedderPassThrough and WanVideoUnit_ImageEmbedderFusingOnly.
  • Unified Sequence Parallelism Updates: The WanS2VModel has been updated to better integrate with unified sequence parallelism, including logic for chunking and gathering hidden states. A debug print statement was also added to indicate USP activation.
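
For readers new to the idea, here is a minimal, self-contained sketch of a TeaCache-style skip heuristic. It is not the code added in this pull request: the class name, the threshold value, and the relative-change metric are illustrative assumptions, and the actual wan2.2s2v integration may differ.

    import torch

    class TeaCacheSketch:
        """Illustrative TeaCache-style cache: skip the transformer blocks when the
        timestep-modulated input has barely changed since the last full forward pass."""

        def __init__(self, rel_l1_threshold: float = 0.05):
            self.rel_l1_threshold = rel_l1_threshold   # hypothetical skip threshold
            self.prev_modulated_input = None           # input seen at the last full pass
            self.cached_residual = None                # residual produced by the last full pass
            self.accumulated_change = 0.0

        def should_skip(self, modulated_input: torch.Tensor) -> bool:
            # Never skip before anything has been cached.
            if self.prev_modulated_input is None or self.cached_residual is None:
                return False
            # Relative L1 change of the modulated input between consecutive steps.
            rel_change = (
                (modulated_input - self.prev_modulated_input).abs().mean()
                / self.prev_modulated_input.abs().mean().clamp(min=1e-8)
            ).item()
            self.accumulated_change += rel_change
            if self.accumulated_change < self.rel_l1_threshold:
                return True                            # small drift: reuse cached residual
            self.accumulated_change = 0.0              # large drift: recompute and reset
            return False

        def store(self, modulated_input: torch.Tensor, residual: torch.Tensor) -> None:
            self.prev_modulated_input = modulated_input.detach()
            self.cached_residual = residual.detach()

In a denoising loop, should_skip would be consulted once per step; when it returns True, the stack of transformer blocks is bypassed and cached_residual is added to the hidden states instead.
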
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its suggestions are incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a 'tea cache' mechanism for wan2.2s2v, refactors variable names for better clarity (e.g., x to latents/hidden_states), and adds support for offline preprocessing in the WanVideoPipeline. While the refactoring improves readability, I've identified several critical bugs. These include using undefined variables after renaming and passing incorrect sequence lengths in checkpointed functions when using unified sequence parallelism, which will likely cause runtime errors or incorrect outputs. Additionally, there are a couple of medium-severity issues: some dead code that should be removed and a debug print statement left in the code. Addressing these issues is crucial for the correctness and maintainability of the new features.


    if self.has_image_input:
-       x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+       hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)  # (b, c_x + c_y, f, h, w)


critical

There seems to be a copy-paste error here. encoder_hidden_states is being concatenated with hidden_states, but based on the original code and the comment, it should be y (the reference image latents). Concatenating text embeddings with image latents along the channel dimension is likely incorrect.

Suggested change
-       hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)  # (b, c_x + c_y, f, h, w)
+       hidden_states = torch.cat([hidden_states, y], dim=1)  # (b, c_x + c_y, f, h, w)

Comment on lines +393 to 403
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            x, context, t_mod, freqs,
            use_reentrant=False,
        )
    else:
-       x = torch.utils.checkpoint.checkpoint(
+       hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            x, context, t_mod, freqs,
            use_reentrant=False,
        )


critical

The variable x is used here in the arguments to torch.utils.checkpoint.checkpoint, but it is not defined in the scope of this function after the refactoring. The variable was renamed to hidden_states. You should use hidden_states instead of x.

Suggested change
-       hidden_states = torch.utils.checkpoint.checkpoint(
-           create_custom_forward(block),
-           x, context, t_mod, freqs,
-           use_reentrant=False,
-       )
-   else:
-       hidden_states = torch.utils.checkpoint.checkpoint(
-           create_custom_forward(block),
-           x, context, t_mod, freqs,
-           use_reentrant=False,
-       )
+       hidden_states = torch.utils.checkpoint.checkpoint(
+           create_custom_forward(block),
+           hidden_states, context, t_mod, freqs,
+           use_reentrant=False,
+       )
+   else:
+       hidden_states = torch.utils.checkpoint.checkpoint(
+           create_custom_forward(block),
+           hidden_states, context, t_mod, freqs,
+           use_reentrant=False,
+       )

Comment on lines +580 to 610
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                t_mod,
                seq_len_x,
                pre_compute_freqs[0],
                use_reentrant=False,
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                hidden_states,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
-           x,
-           context,
+           hidden_states,
+           encoder_hidden_states,
            t_mod,
            seq_len_x,
            pre_compute_freqs[0],
            use_reentrant=False,
        )
-       x = torch.utils.checkpoint.checkpoint(
+       hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-           x,
+           hidden_states,
            use_reentrant=False,
        )


critical

When using gradient checkpointing with unified sequence parallel (USP), you are passing seq_len_x to after_transformer_block. However, seq_len_x is the local sequence length for the current parallel rank. The after_transformer_block function performs an all_gather and requires the global sequence length (seq_len_x_global) to correctly process the gathered tensor. The non-checkpointed path correctly uses seq_len_x_global. This should be corrected for the checkpointed paths as well to avoid errors or incorrect behavior. Also, use_unified_sequence_parallel should be passed to after_transformer_block.

                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            hidden_states,
                            encoder_hidden_states,
                            t_mod,
                            seq_len_x,
                            pre_compute_freqs[0],
                            use_reentrant=False,
                        )
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel)),
                            hidden_states,
                            use_reentrant=False,
                        )
                elif use_gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        t_mod,
                        seq_len_x,
                        pre_compute_freqs[0],
                        use_reentrant=False,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel)),
                        hidden_states,
                        use_reentrant=False,
                    )
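
To make the local-versus-global distinction concrete, the following single-process sketch simulates the gather step. The tensor shapes, the world_size value, and the torch.cat stand-in for the distributed all_gather are assumptions for illustration only.

    import torch

    # Four hypothetical USP ranks, each holding a local chunk of the sequence dimension.
    world_size, seq_len_x, dim = 4, 8, 16                # seq_len_x is the per-rank (local) length
    local_chunks = [torch.randn(1, seq_len_x, dim) for _ in range(world_size)]

    # Stand-in for the all_gather performed inside after_transformer_block.
    gathered = torch.cat(local_chunks, dim=1)            # shape: (1, world_size * seq_len_x, dim)

    seq_len_x_global = world_size * seq_len_x
    assert gathered.shape[1] == seq_len_x_global

    # Slicing the gathered tensor with the local length silently drops most of the sequence,
    # which is why the checkpointed paths must be given seq_len_x_global rather than seq_len_x.
    wrong = gathered[:, :seq_len_x]                      # keeps only 8 of the 32 positions
    right = gathered[:, :seq_len_x_global]               # keeps the full gathered sequence
    print(wrong.shape, right.shape)                      # torch.Size([1, 8, 16]) torch.Size([1, 32, 16])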

def process(self, pipe: WanVideoPipeline):
    if hasattr(pipe, "use_unified_sequence_parallel"):
        if pipe.use_unified_sequence_parallel:
            print("use_unified_sequence_parallel true")


medium

This print statement appears to be for debugging and should be removed from production code.
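
If the signal is still useful during development, one option (a sketch, not part of this pull request) is to route it through the standard logging module instead of printing unconditionally; the snippet below mirrors the original method's signature.

    import logging

    logger = logging.getLogger(__name__)

    def process(self, pipe: "WanVideoPipeline"):
        # Same check as the original method, but emitted at DEBUG level so it stays
        # silent unless logging is explicitly configured to show it.
        if getattr(pipe, "use_unified_sequence_parallel", False):
            logger.debug("unified sequence parallelism is enabled for this pipeline")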

@Artiprocher
Collaborator

@aviveise Thanks for your contribution. Due to the substantial amount of code changes required, we will rewrite your code, which will be completed within 1–2 weeks.
