 import numpy as np
 import torch
 import torch.nn.functional as F
-
 from typing import Optional
-from torch import nn
+from packaging.version import Version
 
 from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
 from megatron.utils import print_rank_0
@@ -36,9 +35,15 @@
     rearrange = None
 
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    import flash_attn as _flash_attn
+    if Version(getattr(_flash_attn, "__version__", "1")) >= Version("2"):
+        from flash_attn.flash_attn_interface import flash_attn_func
+        FLASH_VERSION = 2
+    else:
+        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+        FLASH_VERSION = 1
 except ImportError:
-    flash_attn_unpadded_func = None
+    FLASH_VERSION = None
 
 
 """ We use the following notation throughout this file:
@@ -508,7 +513,7 @@ class FlashSelfAttention(torch.nn.Module):
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                  device=None, dtype=None):
         super().__init__()
-        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+        assert FLASH_VERSION is not None, ('Please install FlashAttention first, '
                                                       'e.g., with pip install flash-attn')
         assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
         self.causal = causal
@@ -521,10 +526,31 @@ def forward(self, q, k, v):
         ---------
             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
         assert all((i.is_cuda for i in (q, k, v)))
 
+        if FLASH_VERSION == 1:
+            return self._forward_v1(q, k, v)
+
+        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+
+        if self.training:
+            # During training q, k, v always have the same seqlen.
+            assert seqlen_k == seqlen_q
+            is_causal = self.causal
+            dropout_p = self.dropout_p
+        else:
+            # Turn off the FA causal mask after the first autoregressive inference step:
+            # only on the first step do q, k, v share the same seqlen.
+            is_causal = self.causal and (seqlen_q == seqlen_k)
+            dropout_p = 0
+
+        output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=is_causal)
+
+        return output
+
+
+    def _forward_v1(self, q, k, v):
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
 
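
Note (not part of the patch): the retained v1 body is elided from this hunk. The 1.x interface used here is the "unpadded" one, which takes tokens flattened to (total_tokens, heads, head_dim) together with cumulative sequence lengths, whereas 2.x's flash_attn_func accepts batched (B, S, H, D) tensors directly, which is why the new forward can pass q, k, v through unchanged. A sketch of how a v1 call is typically assembled for equal-length sequences (argument order follows the FlashAttention 1.x flash_attn_interface; check it against the installed version):

    import torch
    from einops import rearrange
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func  # flash-attn 1.x

    def v1_attention(q, k, v, dropout_p, softmax_scale, causal):
        # q, k, v: (B, S, H, D). Flatten batch and sequence into one token axis.
        batch_size, seqlen = q.shape[0], q.shape[1]
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in (q, k, v)]
        # With equal-length sequences the cumulative lengths are just multiples of seqlen.
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                  dtype=torch.int32, device=q.device)
        out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                                       seqlen, seqlen, dropout_p,
                                       softmax_scale=softmax_scale, causal=causal)
        return rearrange(out, '(b s) ... -> b s ...', b=batch_size)
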
@@ -647,7 +673,7 @@ def __init__(self, init_method,
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
         if self.use_flash_attn:
-            if flash_attn_unpadded_func is None:
+            if FLASH_VERSION is None:
                 raise ImportError('FlashAttention is not installed, please install with '
                                   'pip install flash-attn')
             assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
@@ -882,6 +908,7 @@ def forward(self, hidden_states, attention_mask,
             sq, b, np, hn = query_layer.size()
             # Expand kv to be compatible with flash-attn implementation
             # [sq, b, 1, hn] -> [sq, b, np, hn]
+            # TODO: this should be skippable for flash-attn 2, but it currently triggers an illegal memory access.
             key_layer = key_layer.expand((sq, b, np, hn))
             value_layer = value_layer.expand((sq, b, np, hn))
             q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
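
Note (not part of the patch): the TODO above refers to FlashAttention 2's native multi-query/grouped-query support, where k and v may carry fewer heads than q, so the expand could in principle be dropped on the v2 path. The cost of keeping it is that expand alone is a stride-0 view, but the .contiguous() on the rearranged tensors materializes np copies of the key and value heads. A minimal standalone illustration of where the copy actually happens (sizes are arbitrary):

    import torch

    sq, b, heads, hn = 4, 2, 8, 64            # illustrative sizes only
    kv = torch.randn(sq, b, 1, hn)
    view = kv.expand(sq, b, heads, hn)        # a view: no new memory, stride 0 on the head dim
    copy = view.contiguous()                  # the copy is made here: heads x the original storage
    assert view.data_ptr() == kv.data_ptr()
    assert view.stride(2) == 0
    assert copy.shape == (sq, b, heads, hn)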