 
 
 # `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves
-_flash_fn = None
-_flash_varlen_fn = None
+_fdma_fn = None
+_fdma_varlen_fn = None
 _pad_fn = None
 _unpad_fn = None
+_create_mask_fn = None
 
 # function that processes kwargs, generalized to handle any supported kwarg within the function
 _process_flash_kwargs_fn = None
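For readers skimming the hunk above: the module-level placeholders form a simple lazy cache. Since dynamo cannot trace `globals()` mutation, the callables live in named module globals and are filled on first use. A minimal sketch of that pattern, with illustrative names rather than the actual helpers in this file:

```python
# Minimal sketch of the lazy-cache pattern used above (illustrative names only).
from typing import Callable, Optional

_cached_fn: Optional[Callable] = None  # module-level slot instead of mutating globals()

def _import_backend() -> Callable:
    # stand-in for the real lazy import (e.g. the flash_dmattn functions above)
    def backend(*args, **kwargs):
        raise NotImplementedError("placeholder backend")
    return backend

def get_backend(force_import: bool = False) -> Callable:
    global _cached_fn
    if force_import or _cached_fn is None:
        _cached_fn = _import_backend()
    return _cached_fn
```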
@@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fdma = is_flash_dmattn_available()
 
-    pad_input, unpad_input = _pad_input, _unpad_input
-
     if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma):
         from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
         from flash_dmattn.utils.padding import pad_input, unpad_input
+        from flash_dmattn.utils.mask import create_mask
 
-    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input
+    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask
 
 
 def _lazy_define_process_function(flash_function):
@@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc
     NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
     work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
-    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
-        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
-
+    global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn
+    if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]):
+        _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation)
+
     global _process_flash_kwargs_fn
     if force_import or _process_flash_kwargs_fn is None:
-        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+        _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn)
 
-    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+    return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn
 
 
 def _index_first_axis(tensor, indices):
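A quick usage sketch of the updated loader, mirroring the call-site change further down in this diff; it assumes `flash_dmattn` is installed so all five callables resolve:

```python
# Sketch only: unpack the enlarged tuple returned by the loader defined above.
(fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = (
    lazy_import_flash_dynamic_mask_attention("flash_dmattn")
)
assert all(fn is not None for fn in (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn))
```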
@@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices):
     return reshaped_tensor[indices]
 
 
-def _unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    return (
-        _index_first_axis(hidden_states, indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
-def _pad_input(hidden_states, indices, batch, seqlen):
-    """
-    pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    dim = hidden_states.shape[1:]
-    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    output[indices] = hidden_states
-    return output.view(batch, seqlen, *dim)
-
-
 def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.
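With the local fallbacks deleted above, the padding helpers always come from `flash_dmattn.utils.padding`. Based on the docstrings of the removed `_unpad_input`/`_pad_input`, the round trip looks roughly like the sketch below (it assumes the imported helpers keep the same signatures as the removed fallbacks):

```python
# Sketch of the unpad -> varlen attention -> repad round trip, using the
# signatures documented by the removed fallbacks. Assumes flash_dmattn is installed.
import torch
from flash_dmattn.utils.padding import pad_input, unpad_input

batch, seqlen, dim = 2, 5, 8
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# (total_nnz, dim) tokens, flat indices, cumulative lengths, max length, used lengths
unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
# ... run the varlen attention kernel on `unpadded` here ...
repadded = pad_input(unpadded, indices, batch, seqlen)  # back to (batch, seqlen, dim)
```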
@@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward(
527476 "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None."
528477 )
529478
530- (flash_fn , flash_varlen_fn , pad_fn , unpad_fn ), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention (implementation )
479+ (fdma_fn , fdma_varlen_fn , pad_fn , unpad_fn , create_mask_fn ), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention (implementation )
531480
532481 # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
533482 query_states , key_states , value_states , attention_bias = fdma_peft_integration_check (
@@ -546,10 +495,10 @@ def _flash_dynamic_mask_attention_forward(
         **kwargs,
     )
 
-    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+    # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
     # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
     # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-    # use `flash_varlen_fn` knowing we already have all the necessary kwargs.
+    # use `fdma_varlen_fn` knowing we already have all the necessary kwargs.
     #
     # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
     # See #39121 for more information.
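For context on the varlen path these comments describe: the kernel sees all sequences flattened into one buffer plus cumulative sequence lengths. A small illustration of those cumulative lengths, matching the `F.pad(torch.cumsum(...), (1, 0))` construction used by the unpad helpers:

```python
# Illustration only: cumulative lengths for a packed batch of three sequences
# with 3, 5 and 2 valid tokens.
import torch
import torch.nn.functional as F

seqlens_in_batch = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
# sequence b occupies flat positions cu_seqlens[b] ... cu_seqlens[b + 1] - 1
```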
@@ -569,7 +518,7 @@ def _flash_dynamic_mask_attention_forward(
569518 if "mps" in str (q .device ):
570519 cu_seq_lens_k = cu_seq_lens_k .clone ()
571520
572- out_unpad = flash_varlen_fn (
521+ out_unpad = fdma_varlen_fn (
573522 q ,
574523 k ,
575524 v ,
@@ -600,7 +549,7 @@ def _flash_dynamic_mask_attention_forward(
600549 if "mps" in str (q .device ):
601550 cu_seq_lens_k = cu_seq_lens_k .clone ()
602551
603- out = flash_varlen_fn (
552+ out = fdma_varlen_fn (
604553 q ,
605554 k ,
606555 v ,
@@ -624,24 +573,17 @@ def _flash_dynamic_mask_attention_forward(
         and window_size is not None
         and key_length > window_size
     ):
-        min_dtype = torch.finfo(query_states.dtype).min
-        if attention_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-            attention_mask = torch.zeros_like(
-                attention_bias, dtype=torch.bool, device=attention_bias.device
-            ).scatter_(-1, topk_indices, topk_values != min_dtype)
-        else:
-            topk_values, topk_indices = torch.topk(
-                attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
-            )
-            attention_mask = torch.zeros_like(
-                attention_bias, dtype=torch.bool, device=attention_bias.device
-            ).scatter_(-1, topk_indices, topk_values != min_dtype)
+        attention_mask = create_mask_fn(
+            attention_bias,
+            attention_mask,
+            batch_size=query_states.size(0),
+            query_len=query_length,
+            key_len=key_length,
+            window_size=window_size,
+            min_dtype=torch.finfo(attention_bias.dtype).min,
+        )
 
-        out = flash_fn(
+        out = fdma_fn(
             query_states,
             key_states,
             value_states,
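The replaced branch above kept, for every query, the `window_size` largest bias entries and turned them into a boolean mask. A self-contained sketch of that selection, lifted from the deleted lines; `create_mask` from `flash_dmattn.utils.mask` presumably wraps equivalent logic behind the keyword interface shown in the added call:

```python
# Sketch reproducing the removed top-k window masking; `create_mask` is
# assumed to provide equivalent behaviour.
import torch

def topk_window_mask(attention_bias, attention_mask, window_size):
    min_dtype = torch.finfo(attention_bias.dtype).min
    # ignore already-masked positions when ranking, exactly as the removed branch did
    scores = attention_bias if attention_mask is None else attention_bias.masked_fill(~attention_mask, min_dtype)
    # keep the window_size highest-scoring keys per query
    topk_values, topk_indices = torch.topk(scores.detach(), window_size, dim=-1, largest=True, sorted=False)
    return torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(-1, topk_indices, topk_values != min_dtype)
```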