File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed 
src/diffusers/models/transformers Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -106,8 +106,8 @@ def forward(
106106            hidden_states = norm_hidden_states , encoder_hidden_states = norm_encoder_hidden_states 
107107        )
108108
109-         hidden_states  +=  gate_msa .unsqueeze (1 ) *  attn_hidden_states 
110-         encoder_hidden_states  +=  c_gate_msa .unsqueeze (1 ) *  attn_encoder_hidden_states 
109+         hidden_states  =   hidden_states   +  gate_msa .unsqueeze (1 ) *  attn_hidden_states 
110+         encoder_hidden_states  =   encoder_hidden_states   +  c_gate_msa .unsqueeze (1 ) *  attn_encoder_hidden_states 
111111
112112        # norm & modulate 
113113        norm_hidden_states  =  self .norm2 (hidden_states )
@@ -120,8 +120,8 @@ def forward(
120120        norm_hidden_states  =  torch .cat ([norm_encoder_hidden_states , norm_hidden_states ], dim = 1 )
121121        ff_output  =  self .ff (norm_hidden_states )
122122
123-         hidden_states  +=  gate_mlp .unsqueeze (1 ) *  ff_output [:, text_seq_length :]
124-         encoder_hidden_states  +=  c_gate_mlp .unsqueeze (1 ) *  ff_output [:, :text_seq_length ]
123+         hidden_states  =   hidden_states   +  gate_mlp .unsqueeze (1 ) *  ff_output [:, text_seq_length :]
124+         encoder_hidden_states  =   encoder_hidden_states   +  c_gate_mlp .unsqueeze (1 ) *  ff_output [:, :text_seq_length ]
125125
126126        if  hidden_states .dtype  ==  torch .float16 :
127127            hidden_states  =  hidden_states .clip (- 65504 , 65504 )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments