@@ -19,7 +19,6 @@ def __init__(
1919 qkv_bias = False ,
2020 use_rope = False ,
2121 max_len = 10000 ,
22- use_flash_attention = True ,
2322 ):
2423 super ().__init__ ()
2524
@@ -31,7 +30,6 @@ def __init__(
3130 self .d_in = d_in
3231 self .use_rope = use_rope
3332 self .rope_dim = self .head_dim
34- self .use_flash_attention = use_flash_attention
3533
3634 self .qkv = nn .Linear (d_in , 3 * d_out , bias = qkv_bias )
3735 self .proj = nn .Linear (d_in , d_out )
@@ -45,6 +43,13 @@ def __init__(
4543 embed_dim = d_out , max_len = max_len
4644 )
4745
46+ self .sdp_backends = [
47+ SDPBackend .FLASH_ATTENTION ,
48+ SDPBackend .EFFICIENT_ATTENTION ,
49+ SDPBackend .CUDNN_ATTENTION ,
50+ SDPBackend .MATH ,
51+ ]
52+
4853 def forward (self , x ):
4954 batch_size , num_tokens , embed_dim = x .shape
5055
@@ -74,17 +79,7 @@ def forward(self, x):
7479
7580 use_dropout = 0.0 if not self .training else self .dropout
7681
77- if self .use_flash_attention :
78- with sdpa_kernel (SDPBackend .FLASH_ATTENTION ):
79- context_vec = nn .functional .scaled_dot_product_attention (
80- queries ,
81- keys ,
82- values ,
83- attn_mask = None ,
84- dropout_p = use_dropout ,
85- is_causal = True ,
86- )
87- else :
82+ with sdpa_kernel (self .sdp_backends , set_priority = True ):
8883 context_vec = nn .functional .scaled_dot_product_attention (
8984 queries ,
9085 keys ,
@@ -290,15 +285,13 @@ def __init__(
290285 alpha = 0.1 ,
291286 beta = 0.1 ,
292287 use_rope = False ,
293- use_flash_attention = True ,
294288 ):
295289 super (ConformerBlock , self ).__init__ ()
296290 self .feed_forward_residual_factor = feed_forward_residual_factor
297291 self .use_deepnorm = use_deepnorm
298292 self .alpha = alpha
299293 self .beta = beta
300294 self .use_rope = use_rope
301- self .use_flash_attention = use_flash_attention
302295
303296 self .ff1 = FeedForwardBlock (embed_dim , feed_forward_expansion_factor , dropout )
304297 self .attention = MHAPyTorchScaledDotProduct (
@@ -307,7 +300,6 @@ def __init__(
307300 num_heads = num_heads ,
308301 dropout = dropout ,
309302 use_rope = use_rope ,
310- use_flash_attention = self .use_flash_attention ,
311303 )
312304 self .conv_block = ConvBlock (embed_dim , conv_kernel_size , dropout )
313305 self .ff2 = FeedForwardBlock (embed_dim , feed_forward_expansion_factor , dropout )
@@ -399,7 +391,6 @@ def __init__(
399391 use_rope : bool ,
400392 num_patches : int ,
401393 patch_size : Tuple [int , int ] | None = None ,
402- use_flash_attention : bool = True ,
403394 ):
404395 super (Conformer , self ).__init__ ()
405396 self .embed_dim = embed_dim
@@ -414,7 +405,6 @@ def __init__(
414405 self .use_deepnorm = use_deepnorm
415406 self .use_rope = use_rope
416407 self .num_patches = num_patches
417- self .use_flash_attention = use_flash_attention
418408
419409 self .input_dropout = nn .Dropout (input_dropout )
420410
@@ -437,7 +427,6 @@ def __init__(
437427 alpha = self .alpha_deepnorm ,
438428 beta = self .beta_deepnorm ,
439429 use_rope = self .use_rope ,
440- use_flash_attention = self .use_flash_attention ,
441430 )
442431 for _ in range (depth )
443432 ]
0 commit comments