Commit 462980b

Support flash attn 2 (#72)
1 parent ebea9f2 commit 462980b

1 file changed: megatron/model/transformer.py (34 additions, 7 deletions)
@@ -6,9 +6,8 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-
 from typing import Optional
-from torch import nn
+from packaging.version import Version

 from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
 from megatron.utils import print_rank_0
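The packaging.version.Version import added above backs the flash-attn version gate introduced in the next hunk; a quick illustrative check of that comparison (not part of the commit) shows how releases are routed, with the getattr fallback "1" treating a missing __version__ as pre-2.0:

from packaging.version import Version

for v in ("1.0.9", "2.0.4", "2.3.2"):
    print(v, Version(v) >= Version("2"))
# 1.0.9 False -> FLASH_VERSION = 1 (flash_attn_unpadded_func path)
# 2.0.4 True  -> FLASH_VERSION = 2 (flash_attn_func path)
# 2.3.2 True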
@@ -36,9 +36,15 @@
 rearrange = None

 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    import flash_attn as _flash_attn
+    if Version(getattr(_flash_attn, "__version__", "1")) >= Version("2"):
+        from flash_attn.flash_attn_interface import flash_attn_func
+        FLASH_VERSION = 2
+    else:
+        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+        FLASH_VERSION = 1
 except ImportError:
-    flash_attn_unpadded_func = None
+    FLASH_VERSION = None


 """ We use the following notation throughout this file:
@@ -508,7 +513,7 @@ class FlashSelfAttention(torch.nn.Module):
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                  device=None, dtype=None):
         super().__init__()
-        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+        assert FLASH_VERSION is not None, ('Please install FlashAttention first, '
                                                       'e.g., with pip install flash-attn')
         assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
         self.causal = causal
@@ -521,10 +526,31 @@ def forward(self, q, k, v):
         ---------
             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
         assert all((i.is_cuda for i in (q,k,v)))

+        if FLASH_VERSION == 1:
+            return self._forward_v1(q, k, v)
+
+        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+
+        if self.training:
+            # during training q, k, v always have the same seqlen
+            assert seqlen_k == seqlen_q
+            is_causal = self.causal
+            dropout_p = self.dropout_p
+        else:
+            # turn off the FA causal mask after the first autoregressive inference step;
+            # only on the first autoregressive step do q, k, v have the same seqlen
+            is_causal = self.causal and (seqlen_q == seqlen_k)
+            dropout_p = 0
+
+        output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=is_causal)
+
+        return output
+
+
+    def _forward_v1(self, q, k, v):
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]

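The added branch keeps the FlashAttention causal mask only while query and key lengths match, i.e. during training and the prefill step, and drops it once incremental decoding begins. A small illustration of that selection logic (not part of the commit; a plain-Python mirror of the branch above, assuming causal attention in eval mode):

# Mirrors the is_causal / dropout selection added to FlashSelfAttention.forward.
def causal_flag(causal, seqlen_q, seqlen_k, training):
    if training:
        assert seqlen_q == seqlen_k  # training batches share one seqlen
        return causal
    return causal and (seqlen_q == seqlen_k)

print(causal_flag(True, 128, 128, training=False))  # True:  prefill step keeps the mask
print(causal_flag(True, 1, 129, training=False))    # False: later autoregressive steps drop it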
@@ -647,7 +673,7 @@ def __init__(self, init_method,
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         if self.use_flash_attn:
-            if flash_attn_unpadded_func is None:
+            if FLASH_VERSION is None:
                 raise ImportError('FlashAttention is not installed, please install with '
                                   'pip install flash-attn')
             assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
@@ -882,6 +908,7 @@ def forward(self, hidden_states, attention_mask,
             sq, b, np, hn = query_layer.size()
             # Expand kv to be compatible with flash-attn implementation
             # [sq, b, 1, hn] -> [sq, b, np, hn]
+            # TODO: This should be skippable for flash 2, but getting illegal memory access.
             key_layer = key_layer.expand((sq, b, np, hn))
             value_layer = value_layer.expand((sq, b, np, hn))
             q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
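A rough usage sketch of the FlashAttention-2 call this commit enables, mirroring the flash_attn_func invocation in the new forward(); it is not part of the commit and assumes flash-attn >= 2 plus a CUDA device, with arbitrary example tensor sizes:

import torch
from flash_attn.flash_attn_interface import flash_attn_func

b, s, h, d = 2, 128, 16, 64  # batch, seqlen, heads, head dim
q, k, v = (torch.randn(b, s, h, d, device='cuda', dtype=torch.float16) for _ in range(3))

# dropout_p=0.0 and causal=True correspond to eval-mode prefill in the new forward()
out = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True)
print(out.shape)  # torch.Size([2, 128, 16, 64])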
