- 
                Notifications
    
You must be signed in to change notification settings  - Fork 986
 
Adding tea cache wan2.2s2v #1017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Open
      
      
            aviveise
  wants to merge
  23
  commits into
  modelscope:main
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
Smiti-AI:adding_tea_cache_wan2.2s2v
  
      
      
   
  
    
  
  
  
 
  
      
    base: main
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
      
        
          +276
        
        
          −173
        
        
          
        
      
    
  
  
     Open
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            23 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      6eaef81
              
                making units not mendatory if models not available
              
              
                aviveise 5879063
              
                adding encoding only flag
              
              
                aviveise bb800bb
              
                fixing
              
              
                aviveise 9af7472
              
                fix in method
              
              
                aviveise b6ecbd9
              
                fix in method
              
              
                aviveise c4ce048
              
                adding offline preprocessing to from_pretrained method
              
              
                aviveise df68cbb
              
                adding passive input video unit
              
              
                aviveise 9187e9d
              
                adding tea cache to wan 2.2 with wan 2.1 coefficients
              
              
                aviveise 64b3508
              
                fix
              
              
                aviveise 440ce5b
              
                renaming dit inputs
              
              
                aviveise 98ffebf
              
                renaming dit inputs
              
              
                aviveise cfcf4da
              
                using dit forward
              
              
                aviveise 38d594a
              
                using dit forward
              
              
                aviveise 7aedfd9
              
                fixing usp
              
              
                aviveise b1a62a3
              
                disabling usp dit method overide
              
              
                aviveise 5aa7363
              
                adding prints for debug
              
              
                aviveise 6cb86d7
              
                print
              
              
                aviveise defe06d
              
                print
              
              
                aviveise 982f810
              
                fix in seq_len_x_global
              
              
                aviveise b1d9717
              
                fix in dit model
              
              
                aviveise 144a475
              
                removing prints
              
              
                aviveise 0e6e3e2
              
                removing prints
              
              
                aviveise 8b4c859
              
                adding fps to predict
              
              
                aviveise File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 
          
            
          
           | 
    @@ -211,7 +211,7 @@ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int | |||||||||||||||||||||||||||||||||||||||||||||||
| self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.gate = GateModule() | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x, context, t_mod, freqs): | ||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, hidden_states, encoder_hidden_states, t_mod, freqs): | ||||||||||||||||||||||||||||||||||||||||||||||||
| has_seq = len(t_mod.shape) == 4 | ||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_dim = 2 if has_seq else 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| # msa: multi-head self-attention mlp: multi-layer perceptron | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
        
          
        
         | 
    @@ -222,12 +222,12 @@ def forward(self, x, context, t_mod, freqs): | |||||||||||||||||||||||||||||||||||||||||||||||
| shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), | ||||||||||||||||||||||||||||||||||||||||||||||||
| shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_x = modulate(self.norm1(x), shift_msa, scale_msa) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = x + self.cross_attn(self.norm3(x), context) | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = self.gate(x, gate_mlp, self.ffn(input_x)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| class MLP(torch.nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
        
          
        
         | 
    @@ -244,10 +244,10 @@ def __init__(self, in_dim, out_dim, has_pos_emb=False): | |||||||||||||||||||||||||||||||||||||||||||||||
| if has_pos_emb: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x): | ||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, hidden_states): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if self.has_pos_emb: | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return self.proj(x) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = hidden_states + self.emb_pos.to(dtype=hidden_states.dtype, device=hidden_states.device) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return self.proj(hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| class Head(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
        
          
        
         | 
    @@ -259,14 +259,14 @@ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps | |||||||||||||||||||||||||||||||||||||||||||||||
| self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x, t_mod): | ||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, hidden_states, t_mod): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(t_mod.shape) == 3: | ||||||||||||||||||||||||||||||||||||||||||||||||
| shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = (self.head(self.norm(hidden_states) * (1 + scale.squeeze(2)) + shift.squeeze(2))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = (self.head(self.norm(x) * (1 + scale) + shift)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = (self.head(self.norm(hidden_states) * (1 + scale) + shift)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| class WanModel(torch.nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
          
            
          
           | 
    @@ -354,9 +354,9 @@ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): | |||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, | ||||||||||||||||||||||||||||||||||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| timestep: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| context: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| clip_feature: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| y: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| use_gradient_checkpointing: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
        
          
        
         | 
    @@ -366,20 +366,20 @@ def forward(self, | |||||||||||||||||||||||||||||||||||||||||||||||
| t = self.time_embedding( | ||||||||||||||||||||||||||||||||||||||||||||||||
| sinusoidal_embedding_1d(self.freq_dim, timestep)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| context = self.text_embedding(context) | ||||||||||||||||||||||||||||||||||||||||||||||||
| context = self.text_embedding(encoder_hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| if self.has_image_input: | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # (b, c_x + c_y, f, h, w) | ||||||||||||||||||||||||||||||||||||||||||||||||
| clip_embdding = self.img_emb(clip_feature) | ||||||||||||||||||||||||||||||||||||||||||||||||
| context = torch.cat([clip_embdding, context], dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| x, (f, h, w) = self.patchify(x) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, (f, h, w) = self.patchify(hidden_states) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| freqs = torch.cat([ | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ], dim=-1).reshape(f * h * w, 1, -1).to(hidden_states.device) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| def create_custom_forward(module): | ||||||||||||||||||||||||||||||||||||||||||||||||
| def custom_forward(*inputs): | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
        
          
        
         | 
    @@ -390,23 +390,23 @@ def custom_forward(*inputs): | |||||||||||||||||||||||||||||||||||||||||||||||
| if self.training and use_gradient_checkpointing: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if use_gradient_checkpointing_offload: | ||||||||||||||||||||||||||||||||||||||||||||||||
| with torch.autograd.graph.save_on_cpu(): | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = torch.utils.checkpoint.checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = torch.utils.checkpoint.checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||
| create_custom_forward(block), | ||||||||||||||||||||||||||||||||||||||||||||||||
| x, context, t_mod, freqs, | ||||||||||||||||||||||||||||||||||||||||||||||||
| use_reentrant=False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = torch.utils.checkpoint.checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = torch.utils.checkpoint.checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||
| create_custom_forward(block), | ||||||||||||||||||||||||||||||||||||||||||||||||
| x, context, t_mod, freqs, | ||||||||||||||||||||||||||||||||||||||||||||||||
| use_reentrant=False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
         
      Comment on lines
    
      +393
     to 
      403
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable  
        Suggested change
       
    
  | 
||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = block(x, context, t_mod, freqs) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = block(hidden_states, context, t_mod, freqs) | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| x = self.head(x, t) | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = self.unpatchify(x, (f, h, w)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.head(hidden_states, t) | ||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states = self.unpatchify(hidden_states, (f, h, w)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return hidden_states | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
     | 
||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||
| def state_dict_converter(): | ||||||||||||||||||||||||||||||||||||||||||||||||
| 
          
            
          
           | 
    ||||||||||||||||||||||||||||||||||||||||||||||||
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There seems to be a copy-paste error here.
encoder_hidden_statesis being concatenated withhidden_states, but based on the original code and the comment, it should bey(the reference image latents). Concatenating text embeddings with image latents along the channel dimension is likely incorrect.