Adding tea cache wan2.2s2v #1017
base: main
Conversation
Summary of Changes

Hello @aviveise, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors core components of the

Highlights
Code Review
This pull request introduces a 'tea cache' mechanism for wan2.2s2v, refactors variable names for better clarity (e.g., x to latents/hidden_states), and adds support for offline preprocessing in the WanVideoPipeline. While the refactoring improves readability, I've identified several critical bugs. These include using undefined variables after renaming and passing incorrect sequence lengths in checkpointed functions when using unified sequence parallelism, which will likely cause runtime errors or incorrect outputs. Additionally, there are a couple of medium-severity issues: some dead code that should be removed and a debug print statement left in the code. Addressing these issues is crucial for the correctness and maintainability of the new features.
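For reviewers new to the technique this PR adds: TeaCache-style caching skips the transformer blocks on denoising steps whose timestep-modulated inputs have barely changed, and replays the residual cached from the last fully computed step. The sketch below is a generic, hedged illustration of that idea; the class name, threshold, and distance metric are assumptions for exposition, not this PR's implementation (real TeaCache variants typically also rescale the distance with model-specific polynomial coefficients).

```python
import torch

class SimpleTeaCache:
    """Minimal timestep-aware cache: skip the transformer blocks when the
    modulated input has drifted less than rel_l1_thresh since the last full
    forward pass, and replay the cached residual instead."""

    def __init__(self, rel_l1_thresh: float = 0.05):
        self.rel_l1_thresh = rel_l1_thresh
        self.accumulated_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None

    def should_skip(self, modulated_input: torch.Tensor) -> bool:
        if self.previous_modulated_input is None:
            skip = False  # the first step must always be computed in full
        else:
            rel_l1 = ((modulated_input - self.previous_modulated_input).abs().mean()
                      / self.previous_modulated_input.abs().mean()).item()
            self.accumulated_distance += rel_l1
            # only skip while the drift is small and a residual exists to replay
            skip = (self.accumulated_distance < self.rel_l1_thresh
                    and self.previous_residual is not None)
        if not skip:
            self.accumulated_distance = 0.0
        self.previous_modulated_input = modulated_input.detach().clone()
        return skip

    def store(self, before: torch.Tensor, after: torch.Tensor) -> None:
        # remember how the block stack transformed the hidden states on a full step
        self.previous_residual = (after - before).detach()

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # cheap stand-in for running all blocks on a skipped step
        return hidden_states + self.previous_residual
```

In a denoising loop, one would call should_skip(...) once per step: on skipped steps apply() substitutes for the block stack, and on fully computed steps store() refreshes the cached residual.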
```diff
 if self.has_image_input:
-    x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+    hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)  # (b, c_x + c_y, f, h, w)
```
There seems to be a copy-paste error here. encoder_hidden_states is being concatenated with hidden_states, but based on the original code and the comment, it should be y (the reference image latents). Concatenating text embeddings with image latents along the channel dimension is likely incorrect.
Suggested change:

```diff
-hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)  # (b, c_x + c_y, f, h, w)
+hidden_states = torch.cat([hidden_states, y], dim=1)  # (b, c_x + c_y, f, h, w)
```
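For context on why the channel-wise concat is suspect: reference-image latents share the video latents' (b, c, f, h, w) layout, while text embeddings are (b, tokens, dim), so the two cannot be concatenated along dim=1. The shapes below are illustrative assumptions, not the model's real dimensions.

```python
import torch

# Assumed, illustrative shapes only.
hidden_states = torch.randn(1, 16, 21, 60, 104)       # video latents (b, c, f, h, w)
y = torch.randn(1, 20, 21, 60, 104)                   # reference-image latents, same layout
encoder_hidden_states = torch.randn(1, 512, 4096)     # text embeddings (b, tokens, dim)

latents = torch.cat([hidden_states, y], dim=1)        # OK -> (1, 36, 21, 60, 104)

try:
    torch.cat([hidden_states, encoder_hidden_states], dim=1)
except RuntimeError as e:
    print("channel concat with text embeddings fails:", e)
```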
```diff
             hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
                 x, context, t_mod, freqs,
                 use_reentrant=False,
             )
         else:
-            x = torch.utils.checkpoint.checkpoint(
+            hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
                 x, context, t_mod, freqs,
                 use_reentrant=False,
             )
```
The variable x is used here in the arguments to torch.utils.checkpoint.checkpoint, but it is not defined in the scope of this function after the refactoring. The variable was renamed to hidden_states. You should use hidden_states instead of x.
Suggested change:

```diff
             hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
-                x, context, t_mod, freqs,
+                hidden_states, context, t_mod, freqs,
                 use_reentrant=False,
             )
         else:
             hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
-                x, context, t_mod, freqs,
+                hidden_states, context, t_mod, freqs,
                 use_reentrant=False,
             )
```
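For readers unfamiliar with the pattern in these diffs, here is a minimal, self-contained sketch of how a create_custom_forward wrapper is commonly paired with torch.utils.checkpoint; the linear block and tensor shapes are placeholders, not the PR's transformer block.

```python
import torch
import torch.nn as nn

def create_custom_forward(module):
    # checkpoint() wants a plain callable over tensors, so close over the module
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = nn.Linear(64, 64)                                  # placeholder transformer block
hidden_states = torch.randn(2, 10, 64, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of stored,
# trading compute for memory.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states,
    use_reentrant=False,
)
hidden_states.sum().backward()
```

Only the tensors passed positionally to checkpoint() flow through the recomputation, which is why the rename from x to hidden_states has to be applied to both the assignment and the arguments.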
```diff
                 if use_gradient_checkpointing_offload:
                     with torch.autograd.graph.save_on_cpu():
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             create_custom_forward(block),
                             hidden_states,
                             encoder_hidden_states,
                             t_mod,
                             seq_len_x,
                             pre_compute_freqs[0],
                             use_reentrant=False,
                         )
                         hidden_states = torch.utils.checkpoint.checkpoint(
                             create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                             hidden_states,
                             use_reentrant=False,
                         )
                 elif use_gradient_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(block),
-                        x,
-                        context,
+                        hidden_states,
+                        encoder_hidden_states,
                         t_mod,
                         seq_len_x,
                         pre_compute_freqs[0],
                         use_reentrant=False,
                     )
-                    x = torch.utils.checkpoint.checkpoint(
+                    hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                        x,
+                        hidden_states,
                         use_reentrant=False,
                     )
```
When using gradient checkpointing with unified sequence parallel (USP), you are passing seq_len_x to after_transformer_block. However, seq_len_x is the local sequence length for the current parallel rank. The after_transformer_block function performs an all_gather and requires the global sequence length (seq_len_x_global) to correctly process the gathered tensor. The non-checkpointed path correctly uses seq_len_x_global. This should be corrected for the checkpointed paths as well to avoid errors or incorrect behavior. Also, use_unified_sequence_parallel should be passed to after_transformer_block.
Suggested change:

```python
if use_gradient_checkpointing_offload:
    with torch.autograd.graph.save_on_cpu():
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            encoder_hidden_states,
            t_mod,
            seq_len_x,
            pre_compute_freqs[0],
            use_reentrant=False,
        )
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel)),
            hidden_states,
            use_reentrant=False,
        )
elif use_gradient_checkpointing:
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(block),
        hidden_states,
        encoder_hidden_states,
        t_mod,
        seq_len_x,
        pre_compute_freqs[0],
        use_reentrant=False,
    )
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel)),
        hidden_states,
        use_reentrant=False,
    )
```
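To illustrate the local vs. global sequence length issue outside a distributed setup (a sketch with assumed shapes; the real code performs an all_gather across the sequence-parallel group): each rank holds seq_len_x tokens, the gather reassembles seq_len_x * world_size of them, and any downstream slicing of the gathered tensor must use that global length.

```python
import torch

world_size = 4                     # assumed sequence-parallel group size
seq_len_x = 128                    # local (per-rank) number of video tokens
seq_len_x_global = seq_len_x * world_size
dim = 64

# Stand-ins for the per-rank shards; dist.all_gather would collect these on every rank.
local_shards = [torch.randn(1, seq_len_x, dim) for _ in range(world_size)]
gathered = torch.cat(local_shards, dim=1)          # (1, seq_len_x_global, dim)

# Splitting "video tokens vs. trailing tokens" after the gather must use the
# global length; using the local seq_len_x here would keep only the first shard.
video_tokens = gathered[:, :seq_len_x_global]
assert video_tokens.shape == (1, seq_len_x_global, dim)
```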
```python
def process(self, pipe: WanVideoPipeline):
    if hasattr(pipe, "use_unified_sequence_parallel"):
        if pipe.use_unified_sequence_parallel:
            print("use_unified_sequence_parallel true")
```
@aviveise Thanks for your contribution. Due to the substantial amount of code changes required, we will rewrite your code, which will be completed within 1–2 weeks.
No description provided.