Commit 018694a

Merge pull request #61 from MTG/selective_flash_attention

Select flash attention automatically when available

2 parents: 58d7c8f + 94386ac

2 files changed: +16, −20 lines
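For context, the mechanism this commit adopts is PyTorch's `torch.nn.attention.sdpa_kernel` context manager: given a list of `SDPBackend` values with `set_priority=True` (available in recent PyTorch releases), it treats the list as an ordered preference and dispatches `scaled_dot_product_attention` to the first backend whose constraints the inputs satisfy. A minimal standalone sketch of that behavior follows; the tensor shapes and names are illustrative, not taken from this repo.

import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, tokens, head_dim) -- shapes chosen only for illustration
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

backends = [
    SDPBackend.FLASH_ATTENTION,      # preferred when the hardware/kernel supports it
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.CUDNN_ATTENTION,
    SDPBackend.MATH,                 # always-available reference fallback
]

# With set_priority=True the list is an ordered preference, so the call
# degrades gracefully instead of erroring when flash attention is unavailable.
with sdpa_kernel(backends, set_priority=True):
    out = nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])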

src/nets/common_former.py (8 additions, 1 deletion)

@@ -20,6 +20,13 @@ def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
         self.proj = nn.Linear(d_in, d_out)
         self.dropout = dropout
 
+        self.sdp_backends = [
+            SDPBackend.FLASH_ATTENTION,
+            SDPBackend.EFFICIENT_ATTENTION,
+            SDPBackend.CUDNN_ATTENTION,
+            SDPBackend.MATH,
+        ]
+
     def forward(self, x):
         batch_size, num_tokens, embed_dim = x.shape
 
@@ -37,7 +44,7 @@ def forward(self, x):
 
         use_dropout = 0.0 if not self.training else self.dropout
 
-        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        with sdpa_kernel(self.sdp_backends, set_priority=True):
             context_vec = nn.functional.scaled_dot_product_attention(
                 queries,
                 keys,

src/nets/conformer.py (8 additions, 19 deletions)

@@ -19,7 +19,6 @@ def __init__(
         qkv_bias=False,
         use_rope=False,
         max_len=10000,
-        use_flash_attention=True,
     ):
         super().__init__()
 
@@ -31,7 +30,6 @@ def __init__(
         self.d_in = d_in
         self.use_rope = use_rope
         self.rope_dim = self.head_dim
-        self.use_flash_attention = use_flash_attention
 
         self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
         self.proj = nn.Linear(d_in, d_out)
@@ -45,6 +43,13 @@ def __init__(
             embed_dim=d_out, max_len=max_len
         )
 
+        self.sdp_backends = [
+            SDPBackend.FLASH_ATTENTION,
+            SDPBackend.EFFICIENT_ATTENTION,
+            SDPBackend.CUDNN_ATTENTION,
+            SDPBackend.MATH,
+        ]
+
     def forward(self, x):
         batch_size, num_tokens, embed_dim = x.shape
 
@@ -74,17 +79,7 @@ def forward(self, x):
 
         use_dropout = 0.0 if not self.training else self.dropout
 
-        if self.use_flash_attention:
-            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-                context_vec = nn.functional.scaled_dot_product_attention(
-                    queries,
-                    keys,
-                    values,
-                    attn_mask=None,
-                    dropout_p=use_dropout,
-                    is_causal=True,
-                )
-        else:
+        with sdpa_kernel(self.sdp_backends, set_priority=True):
             context_vec = nn.functional.scaled_dot_product_attention(
                 queries,
                 keys,
@@ -290,15 +285,13 @@ def __init__(
         alpha=0.1,
         beta=0.1,
         use_rope=False,
-        use_flash_attention=True,
     ):
         super(ConformerBlock, self).__init__()
         self.feed_forward_residual_factor = feed_forward_residual_factor
         self.use_deepnorm = use_deepnorm
         self.alpha = alpha
         self.beta = beta
         self.use_rope = use_rope
-        self.use_flash_attention = use_flash_attention
 
         self.ff1 = FeedForwardBlock(embed_dim, feed_forward_expansion_factor, dropout)
         self.attention = MHAPyTorchScaledDotProduct(
@@ -307,7 +300,6 @@ def __init__(
             num_heads=num_heads,
             dropout=dropout,
             use_rope=use_rope,
-            use_flash_attention=self.use_flash_attention,
         )
         self.conv_block = ConvBlock(embed_dim, conv_kernel_size, dropout)
         self.ff2 = FeedForwardBlock(embed_dim, feed_forward_expansion_factor, dropout)
@@ -399,7 +391,6 @@ def __init__(
         use_rope: bool,
         num_patches: int,
         patch_size: Tuple[int, int] | None = None,
-        use_flash_attention: bool = True,
     ):
         super(Conformer, self).__init__()
         self.embed_dim = embed_dim
@@ -414,7 +405,6 @@ def __init__(
         self.use_deepnorm = use_deepnorm
        self.use_rope = use_rope
         self.num_patches = num_patches
-        self.use_flash_attention = use_flash_attention
 
         self.input_dropout = nn.Dropout(input_dropout)
 
@@ -437,7 +427,6 @@ def __init__(
                     alpha=self.alpha_deepnorm,
                     beta=self.beta_deepnorm,
                     use_rope=self.use_rope,
-                    use_flash_attention=self.use_flash_attention,
                 )
                 for _ in range(depth)
             ]
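Taken together, the hunks in both attention modules follow the same shape: build the backend priority list once in `__init__`, then wrap the `scaled_dot_product_attention` call in `forward`, which removes the need for a user-facing `use_flash_attention` flag. A self-contained sketch of that pattern follows; the module name, dimensions, and helper structure are hypothetical, not the repo's own.

import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel

class TinySelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = dropout
        # Ordered preference: fastest kernels first, MATH as universal fallback,
        # mirroring the list this commit adds to both modules.
        self.sdp_backends = [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.MATH,
        ]

    def forward(self, x):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape each to (batch, heads, tokens, head_dim) for SDPA.
        q, k, v = (
            z.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
            for z in (q, k, v)
        )
        drop = self.dropout if self.training else 0.0
        # The first available backend in the priority list services the call.
        with sdpa_kernel(self.sdp_backends, set_priority=True):
            ctx = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=drop)
        ctx = ctx.transpose(1, 2).reshape(b, t, d)
        return self.proj(ctx)

x = torch.randn(2, 16, 64)
print(TinySelfAttention(64, 4)(x).shape)  # torch.Size([2, 16, 64])

The design upside is that the same checkpoint and code path run on machines without flash-attention support (for example CPU-only hosts), since PyTorch silently falls through to the next viable backend rather than raising at call time.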
