2 files changed, +8 -2 lines changed
@@ -873,6 +873,7 @@ def _native_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -884,7 +885,7 @@ def _native_flash_attention(
         query=query,
         key=key,
         value=value,
-        attn_mask=None,  # not supported
+        attn_mask=attn_mask,
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
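Taken together, the two hunks above let callers pass an attention mask through the native flash helper instead of having it silently dropped. The snippet below is a minimal, hypothetical illustration of the keyword arguments the patched signature now forwards, assuming they map onto torch.nn.functional.scaled_dot_product_attention; the tensor shapes and the mask are made up for the example and are not taken from this change.

import torch
import torch.nn.functional as F

# Hypothetical shapes (batch, heads, seq_len, head_dim); illustrative only.
query = torch.randn(1, 8, 128, 64)
key = torch.randn(1, 8, 128, 64)
value = torch.randn(1, 8, 128, 64)

# A boolean mask the helper can now receive instead of the previously
# hard-coded attn_mask=None (True = attend, False = masked out).
attn_mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)

# Same keyword arguments the patched signature exposes:
# attn_mask, dropout_p, is_causal, scale.
out = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
)
print(out.shape)  # torch.Size([1, 8, 128, 64])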
@@ -123,6 +123,12 @@ def apply_rotary_emb(
             query = apply_rotary_emb(query, *rotary_emb)
             key = apply_rotary_emb(key, *rotary_emb)
 
+        if self._attention_backend == "_native_flash-flash_varlen":
+            if not self.is_cross_attention:
+                self._attention_backend = "_native_flash"
+            else:
+                self._attention_backend = "flash_varlen"
+
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
@@ -153,7 +159,6 @@ def apply_rotary_emb(
             is_causal=False,
             backend=self._attention_backend,
         )
-
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
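A rough, standalone sketch of the selection rule the added branch encodes: when the processor is configured with the combined backend token, self-attention resolves to the native flash path while cross-attention keeps the variable-length flash path. The function below is illustrative only; the backend strings and the is_cross_attention flag mirror the diff, everything else is an assumption.

# Illustrative sketch, not the library implementation.
def resolve_attention_backend(backend: str, is_cross_attention: bool) -> str:
    # The combined token requests a different kernel per attention type.
    if backend == "_native_flash-flash_varlen":
        return "flash_varlen" if is_cross_attention else "_native_flash"
    return backend


# Self-attention layers resolve to the native flash backend,
# cross-attention layers to the varlen flash backend.
assert resolve_attention_backend("_native_flash-flash_varlen", is_cross_attention=False) == "_native_flash"
assert resolve_attention_backend("_native_flash-flash_varlen", is_cross_attention=True) == "flash_varlen"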