Skip to content

Commit 68c44f3

Browse files
[FlashAttn] Remove transpose workaround (#5230)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8a33435 commit 68c44f3

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_tensor_desc_benchmark.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
from triton_kernels_benchmark import flash_attention_benchmark
88

9-
# FIXME: Revert temporary source code modification done in last commit of PR #4399.
10-
119

1210
# pylint: disable=unused-argument
1311
@triton.jit
1412
def _attn_fwd_inner(acc, l_i, m_i, q, #
1513
desc_k, desc_v, #
1614
offset_y, dtype: tl.constexpr, start_m, qk_scale, #
17-
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, #
15+
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
1816
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
1917
N_CTX: tl.constexpr):
2018
# range of values handled by this stage
@@ -32,7 +30,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
3230
for start_n in tl.range(lo, hi, BLOCK_N):
3331
start_n = tl.multiple_of(start_n, BLOCK_N)
3432
# -- compute qk ----
35-
k = desc_k.load([0, offsetk_y])
33+
k = desc_k.load([offsetk_y, 0]).T
3634
qk = tl.dot(q, k)
3735
if STAGE == 2:
3836
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -95,8 +93,8 @@ def _attn_fwd_with_tensor_desc(Q, K, V, sm_scale, M, Out, #
9593
block_shape=[BLOCK_M, BLOCK_DMODEL])
9694
desc_v = tl.make_tensor_descriptor(V, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
9795
block_shape=[BLOCK_N, BLOCK_DMODEL])
98-
desc_k = tl.make_tensor_descriptor(K, shape=[BLOCK_DMODEL, y_dim], strides=[1, BLOCK_DMODEL],
99-
block_shape=[BLOCK_DMODEL, BLOCK_N])
96+
desc_k = tl.make_tensor_descriptor(K, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
97+
block_shape=[BLOCK_N, BLOCK_DMODEL])
10098
desc_o = tl.make_tensor_descriptor(Out, shape=[y_dim, BLOCK_DMODEL], strides=[BLOCK_DMODEL, 1],
10199
block_shape=[BLOCK_M, BLOCK_DMODEL])
102100

0 commit comments

Comments
 (0)