Skip to content

Commit 35f2005

Browse files
[FlashAttention] Sync from upstream tensor desc implementation (part 1) (#4467)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent d46b7d2 commit 35f2005

File tree

1 file changed

+39
-30
lines changed

1 file changed

+39
-30
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from triton_kernels_benchmark import flash_attention_benchmark
88

9+
# FIXME: Revert temporary source code modification done in last commit of PR #4399.
10+
911

1012
# pylint: disable=unused-argument
1113
@triton.jit
1214
def _attn_fwd_inner(acc, l_i, m_i, q, #
13-
K_desc, V_desc, #
14-
start_m, qk_scale, #
15+
desc_k, desc_v, #
16+
offset_y, dtype: tl.constexpr, start_m, qk_scale, #
1517
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #
1618
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
1719
N_CTX: tl.constexpr):
@@ -24,13 +26,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
2426
# causal = False
2527
else:
2628
lo, hi = 0, N_CTX
27-
off_k = lo
28-
off_v = lo
29+
offsetk_y = offset_y + lo
30+
offsetv_y = offset_y + lo
2931
# loop over k, v and update accumulator
3032
for start_n in range(lo, hi, BLOCK_N):
3133
start_n = tl.multiple_of(start_n, BLOCK_N)
3234
# -- compute qk ----
33-
k = K_desc.load([0, off_k])
35+
k = desc_k.load([0, offsetk_y])
3436
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
3537
qk += tl.dot(q, k)
3638
if STAGE == 2:
@@ -43,18 +45,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
4345
qk = qk * qk_scale - m_ij[:, None]
4446
p = tl.math.exp2(qk)
4547
l_ij = tl.sum(p, 1)
46-
# -- update m_i and l_i
48+
# -- compute correction factor
4749
alpha = tl.math.exp2(m_i - m_ij)
4850
l_i = l_i * alpha + l_ij
4951
# -- update output accumulator --
5052
acc = acc * alpha[:, None]
51-
# update acc
52-
v = V_desc.load([off_v, 0])
53+
# prepare p and v for the dot
54+
v = desc_v.load([offsetv_y, 0])
55+
# note that this non transposed v for FP8 is only supported on Blackwell
5356
acc += tl.dot(p.to(tl.float16), v)
5457
# update m_i and l_i
58+
# place this at the end of the loop to reduce register pressure
5559
m_i = m_ij
56-
off_v += BLOCK_N
57-
off_k += BLOCK_N
60+
offsetk_y += BLOCK_N
61+
offsetv_y += BLOCK_N
5862
return acc, l_i, m_i
5963

6064

@@ -75,25 +79,28 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
7579
BLOCK_N: tl.constexpr, #
7680
STAGE: tl.constexpr #
7781
): # pylint: disable=unused-argument
78-
82+
dtype = tl.float16
83+
tl.static_assert(BLOCK_N <= BLOCK_DMODEL)
7984
start_m = tl.program_id(2)
8085
off_z = tl.program_id(0)
8186
off_h = tl.program_id(1)
82-
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
87+
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
8388
if N_CTX <= 512:
8489
start_m = tl.program_id(0)
8590
off_z = tl.program_id(2)
86-
qvk_offset = off_z.to(tl.int64) * stride_qh
91+
offset_y = off_z * N_CTX
92+
93+
y_dim = Z * H * N_CTX
94+
desc_q = tl.make_tensor_descriptor(Q, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
95+
block_shape=[BLOCK_M, BLOCK_DMODEL])
96+
desc_v = tl.make_tensor_descriptor(V, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
97+
block_shape=[BLOCK_N, BLOCK_DMODEL])
98+
desc_k = tl.make_tensor_descriptor(K, shape=[BLOCK_DMODEL, y_dim], strides=[1, BLOCK_DMODEL],
99+
block_shape=[BLOCK_DMODEL, BLOCK_N])
100+
desc_o = tl.make_tensor_descriptor(Out, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
101+
block_shape=[BLOCK_M, BLOCK_DMODEL])
87102

88-
# tensor descriptors
89-
Q_desc = tl.make_tensor_descriptor(base=Q + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk),
90-
block_shape=(BLOCK_M, BLOCK_DMODEL))
91-
V_desc = tl.make_tensor_descriptor(base=V + qvk_offset, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn),
92-
block_shape=(BLOCK_N, BLOCK_DMODEL))
93-
K_desc = tl.make_tensor_descriptor(base=K + qvk_offset, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn),
94-
block_shape=(BLOCK_DMODEL, BLOCK_N))
95-
O_desc = tl.make_tensor_descriptor(base=Out + qvk_offset, shape=(N_CTX, BLOCK_DMODEL),
96-
strides=(stride_om, stride_on), block_shape=(BLOCK_M, BLOCK_DMODEL))
103+
qo_offset_y = offset_y + start_m * BLOCK_M
97104
# initialize offsets
98105
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
99106
offs_n = tl.arange(0, BLOCK_N)
@@ -105,27 +112,29 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
105112
qk_scale = sm_scale
106113
qk_scale *= 1.44269504 # 1/log(2)
107114
# load q: it will stay in SRAM throughout
108-
q = Q_desc.load([start_m * BLOCK_M, 0])
115+
q = desc_q.load([qo_offset_y, 0])
109116
# stage 1: off-band
110-
# For causal = True, STAGE = 3, the kernel gets 1 as its STAGE
111-
# For causal = False, STAGE = 1, the kernel gets 3 as its STAGE
117+
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
118+
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
112119
if STAGE & 1:
113-
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc, #
114-
start_m, qk_scale, #
120+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, #
121+
desc_k, desc_v, #
122+
offset_y, dtype, start_m, qk_scale, #
115123
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
116124
4 - STAGE, offs_m, offs_n, N_CTX #
117125
)
118126
# stage 2: on-band
119127
if STAGE & 2:
120-
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_desc, V_desc, #
121-
start_m, qk_scale, #
128+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, #
129+
desc_k, desc_v, #
130+
offset_y, dtype, start_m, qk_scale, #
122131
BLOCK_M, BLOCK_DMODEL, BLOCK_N, #
123132
2, offs_m, offs_n, N_CTX #
124133
)
125134
# epilogue
126135
m_i += tl.math.log2(l_i)
127136
acc = acc / l_i[:, None]
128-
O_desc.store([start_m * BLOCK_M, 0], acc.to(Out.type.element_ty))
137+
desc_o.store([qo_offset_y, 0], acc.to(Out.type.element_ty))
129138

130139

131140
def get_benchmark(

0 commit comments

Comments (0)