@@ -123,11 +123,10 @@ def apply_rotary_emb(
123123            query  =  apply_rotary_emb (query , * rotary_emb )
124124            key  =  apply_rotary_emb (key , * rotary_emb )
125125
126-         if  self ._attention_backend  ==  "_native_flash-flash_varlen" :
127-             if  not  self .is_cross_attention :
128-                 self ._attention_backend  =  "_native_flash" 
129-             else :
130-                 self ._attention_backend  =  "flash_varlen" 
126+         if  not  self .is_cross_attention :
127+             attention_backend  =  "_native_flash" 
128+         else :
129+             attention_backend  =  "flash_varlen" 
131130
132131        # I2V task 
133132        hidden_states_img  =  None 
@@ -145,7 +144,7 @@ def apply_rotary_emb(
145144                attn_mask = None ,
146145                dropout_p = 0.0 ,
147146                is_causal = False ,
148-                 backend = self . _attention_backend ,
147+                 backend = attention_backend ,
149148            )
150149            hidden_states_img  =  hidden_states_img .flatten (2 , 3 )
151150            hidden_states_img  =  hidden_states_img .type_as (query )
@@ -157,7 +156,7 @@ def apply_rotary_emb(
157156            attn_mask = attention_mask ,
158157            dropout_p = 0.0 ,
159158            is_causal = False ,
160-             backend = self . _attention_backend ,
159+             backend = attention_backend ,
161160        )
162161        hidden_states  =  hidden_states .flatten (2 , 3 )
163162        hidden_states  =  hidden_states .type_as (query )
0 commit comments