
Commit b381e8a

[FlashAttn Backward] Use a tensor descriptor instead of regular pointers in the flash attn bwd kernel to improve performance. (#5150)
Use a tensor descriptor in place of the regular pointer arithmetic in the flash attn bwd kernel to improve performance. The table below shows the performance of the flash attn bwd kernel with the tensor descriptor:

```
attn-performance:
     Z   H  N_CTX  D_HEAD  CAUSAL  MODE  Triton-GB/s  XeTLA-GB/s  Triton-GB/s-min  XeTLA-GB/s-min  Triton-GB/s-max  XeTLA-GB/s-max  Triton-TFlops  XeTLA-TFlops  Triton-TFlops-min  XeTLA-TFlops-min  Triton-TFlops-max  XeTLA-TFlops-max  Triton-CV  XeTLA-CV
 0   1  16  16384     128    True   bwd  1.813470  6.079903  1.800877  6.061080  1.913038  6.102785  14.855948  49.806569  14.752785  49.652369  15.671610  49.994019  0.011349  0.002924
 1   1  32  16384      64    True   bwd  5.662760  7.092534  5.655727  7.080177  5.677753  7.101936  46.389333  58.102041  46.331715  58.000813  46.512155  58.179061  0.000423  0.001173
 2   2  16   8192     128    True   bwd  3.556367  21.238178  3.502044  21.034624  3.771544  21.563326  14.566881  86.991577  14.344373  86.157821  15.448244  88.323383  0.007122  0.002600
 3   2  32   8192      64    True   bwd  11.088122  13.702544  11.052767  13.634914  11.109474  13.737674  45.416948  56.125619  45.272133  55.848608  45.504404  56.269511  0.001150  0.002095
 4   4  16   4096     128    True   bwd  7.178239  40.728465  6.932883  40.610607  7.459007  41.030118  14.701033  83.411897  14.198544  83.170523  15.276046  84.029681  0.026862  0.002404
 5   4  32   4096      64    True   bwd  21.158215  37.828013  21.034730  37.661662  21.228383  38.144256  43.332024  77.471771  43.079127  77.131084  43.475728  78.119436  0.002859  0.003174
 6   8  16   2048     128    True   bwd  14.492702  110.975475  14.273039  110.562635  14.649724  111.485411  14.840526  113.638886  14.615592  113.216138  15.001318  114.161061  0.008343  0.003146
 7   8  32   2048      64    True   bwd  38.583480  76.129790  38.397422  75.750483  38.785155  76.279489  39.509484  77.956905  39.318960  77.568494  39.715999  78.110197  0.002182  0.001470
 8  16  16   1024     128    True   bwd  27.843392  213.751071  27.634647  208.796489  27.954944  215.114581  14.255817  109.440548  14.148939  106.903802  14.312931  110.138665  0.002909  0.005026
 9  16  32   1024      64    True   bwd  66.995507  160.982706  66.901201  160.130729  67.134647  161.655129  34.301700  82.423145  34.253415  81.986933  34.372939  82.767426  0.000977  0.002851
10  32  16    512     128    True   bwd  51.142101  395.204357  50.888779  385.470453  51.357985  402.292739  13.092378  101.172315  13.027527  98.680436  13.147644  102.986941  0.002322  0.010136
11  32  32    512      64    True   bwd  104.514001  268.081593  103.852827  257.698691  104.737151  275.941062  26.755584  68.628888  26.586324  65.970865  26.812711  70.640912  0.001417  0.022733
12   4  48   1024      64    True   bwd  65.527128  139.171494  65.255945  137.801301  65.771681  140.283978  33.549890  71.255805  33.411044  70.554266  33.675101  71.825397  0.001954  0.004190
```

The geomean of the Triton TFlops is 24.2.

---------

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 3e0e565 commit b381e8a
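
For context, a minimal sketch (not part of the changed file) contrasting the old pointer-arithmetic load with the new descriptor-based load. It reuses the same `tl.make_tensor_descriptor` / `.load` calls that appear in the diff below; the two helper names are made up for illustration.

```python
import triton
import triton.language as tl


@triton.jit
def _load_qT_pointer_style(Q, stride_tok, stride_d, start_m,  #
                           BLOCK_M1: tl.constexpr, HEAD_DIM: tl.constexpr):
    # Old approach: materialize a [HEAD_DIM, BLOCK_M1] grid of element pointers
    # into Q and load through it; moving to the next block means adding
    # step_m * stride_tok to every pointer in the grid.
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    return tl.load(qT_ptrs)


@triton.jit
def _load_qT_descriptor_style(Q, stride_tok, stride_d, start_m, N_CTX,  #
                              BLOCK_M1: tl.constexpr, HEAD_DIM: tl.constexpr):
    # New approach: describe Q once as a 2D [HEAD_DIM, N_CTX] view (transposed
    # via the strides) and fetch a [HEAD_DIM, BLOCK_M1] block at a column
    # offset; moving to the next block only changes the offset.
    qT_desc = tl.make_tensor_descriptor(Q, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
                                        block_shape=[HEAD_DIM, BLOCK_M1])
    return qT_desc.load([0, start_m])
```

With the descriptor, each loop iteration selects the next block purely by offset, which is why the `qT_ptrs += step_m * stride_tok` style increments are removed in the diff below.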

File tree

1 file changed (+14 −16 lines)


benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 14 additions & 16 deletions
```diff
@@ -213,17 +213,18 @@ def _attn_bwd_dkdv(dk, dv, #
                    # Filled in by the wrapper.
                    start_n, start_m, num_steps, #
                    MASK: tl.constexpr):
-    offs_m = start_m + tl.arange(0, BLOCK_M1)
     offs_n = start_n + tl.arange(0, BLOCK_N1)
-    offs_k = tl.arange(0, HEAD_DIM)
-    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
-    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
+    qT_desc = tl.make_tensor_descriptor(Q, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_M1])
+
+    do_desc = tl.make_tensor_descriptor(DO, shape=[N_CTX, HEAD_DIM], strides=[stride_tok, stride_d],
+                                        block_shape=[BLOCK_M1, HEAD_DIM])
     # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
     curr_m = start_m
     step_m = BLOCK_M1
     for blk_idx in range(num_steps):
-        qT = tl.load(qT_ptrs)
+        qT = qT_desc.load([0, start_m + blk_idx * step_m])
         # Load m before computing qk to reduce pipeline stall.
         offs_m = curr_m + tl.arange(0, BLOCK_M1)
         m = tl.load(M + offs_m)
@@ -233,7 +234,7 @@ def _attn_bwd_dkdv(dk, dv, #
         if MASK:
             mask = (offs_m[None, :] >= offs_n[:, None])
             pT = tl.where(mask, pT, 0.0)
-        do = tl.load(do_ptrs)
+        do = do_desc.load([start_m + blk_idx * step_m, 0])
         # Compute dV.
         ppT = pT
         ppT = ppT.to(tl.float16)
@@ -247,8 +248,6 @@ def _attn_bwd_dkdv(dk, dv, #
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
-        qT_ptrs += step_m * stride_tok
-        do_ptrs += step_m * stride_tok
     return dk, dv
 
 
@@ -267,19 +266,20 @@ def _attn_bwd_dq(dq, q, K, V, #
                  start_m, start_n, num_steps, #
                  MASK: tl.constexpr):
     offs_m = start_m + tl.arange(0, BLOCK_M2)
-    offs_n = start_n + tl.arange(0, BLOCK_N2)
-    offs_k = tl.arange(0, HEAD_DIM)
-    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
-    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
+    kT_desc = tl.make_tensor_descriptor(K, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
+
+    vT_desc = tl.make_tensor_descriptor(V, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
     # D (= delta) is pre-divided by ds_scale.
     Di = tl.load(D + offs_m)
     # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
     curr_n = start_n
     step_n = BLOCK_N2
     for blk_idx in range(num_steps):
-        kT = tl.load(kT_ptrs)
-        vT = tl.load(vT_ptrs)
+        kT = kT_desc.load([0, start_n + blk_idx * step_n])
+        vT = vT_desc.load([0, start_n + blk_idx * step_n])
         qk = tl.dot(q, kT)
         p = tl.math.exp2(qk - m)
         # Autoregressive masking.
@@ -296,8 +296,6 @@ def _attn_bwd_dq(dq, q, K, V, #
         dq += tl.dot(ds, tl.trans(kT))
         # Increment pointers.
         curr_n += step_n
-        kT_ptrs += step_n * stride_tok
-        vT_ptrs += step_n * stride_tok
     return dq
```