Skip to content

Commit f17d322

Browse files
authored
[gdpa] subtiling activation to reduce register pressure
Differential Revision: D81179333 Pull Request resolved: #364
1 parent 8f7b5a2 commit f17d322

File tree

1 file changed

+115
-57
lines changed

1 file changed

+115
-57
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 115 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,6 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS):
104104
return bufIdx, phase
105105

106106

107-
@triton.jit
108-
def _reinterpret(qk_buf, bufIdx_qk):
109-
qk_view = tlx.local_view(qk_buf, bufIdx_qk)
110-
p_view = tlx.local_reinterpret(qk_view, tl.float16)
111-
return p_view
112-
113-
114107
@triton.jit
115108
def _load_tma(
116109
bufIdx, phase, empty_bars, full_bars, buffers, desc, offset_1, offset_0, num_bytes
@@ -146,6 +139,69 @@ def _load_tma(
146139
# qk0, qk1: producers
147140
# p0, p1: sharing tmem spaces, and barriers with qk0, qk1 (consumers)
148141
# o0, o1
142+
143+
144+
@triton.jit
def _add_f32x2(a, b):
    """Element-wise ``a + b`` issued through the packed PTX ``add.f32x2`` op.

    ``pack=2`` makes each inline-asm instance process two elements at once.
    Per the constraint string, ``$0``/``$1`` are the two 32-bit outputs and
    ``$2``..``$5`` are the four 32-bit inputs.  The asm packs ``{$2, $3}`` and
    ``{$4, $5}`` into 64-bit registers, adds them with one ``add.f32x2`` and
    unpacks the 64-bit result back into ``$0``/``$1``.

    Args:
        a: first addend, passed as the first inline-asm argument.
        b: second addend, passed as the second inline-asm argument.

    Returns:
        The result of the inline-asm call, declared as ``tl.float32``.
    """
    # NOTE(review): presumably a and b are float32 values of the same shape,
    # and the target must support the f32x2 PTX instructions -- confirm.
    return tl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            add.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        "=r,=r,r,r,r,r",
        [a, b],
        dtype=tl.float32,
        is_pure=True,  # no side effects declared for the asm
        pack=2,  # two elements per asm instance, matching the f32x2 width
    )
162+
163+
164+
@triton.jit
def _mul_f32x2(a, b):
    """Element-wise ``a * b`` issued through the packed PTX ``mul.f32x2`` op.

    ``pack=2`` makes each inline-asm instance process two elements at once.
    Per the constraint string, ``$0``/``$1`` are the two 32-bit outputs and
    ``$2``..``$5`` are the four 32-bit inputs.  The asm packs ``{$2, $3}`` and
    ``{$4, $5}`` into 64-bit registers, multiplies them with one
    ``mul.f32x2`` and unpacks the 64-bit result back into ``$0``/``$1``.

    Args:
        a: first factor, passed as the first inline-asm argument.
        b: second factor, passed as the second inline-asm argument.

    Returns:
        The result of the inline-asm call, declared as ``tl.float32``.
    """
    # NOTE(review): presumably a and b are float32 values of the same shape,
    # and the target must support the f32x2 PTX instructions -- confirm.
    return tl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mul.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        "=r,=r,r,r,r,r",
        [a, b],
        dtype=tl.float32,
        is_pure=True,  # no side effects declared for the asm
        pack=2,  # two elements per asm instance, matching the f32x2 width
    )
182+
183+
184+
@triton.jit
def _fma_f32x2(a, b, c):
    """Element-wise fused multiply-add via the packed PTX ``fma.rn.f32x2`` op.

    ``pack=2`` makes each inline-asm instance process two elements at once.
    Per the constraint string, ``$0``/``$1`` are the two 32-bit outputs and
    ``$2``..``$7`` are the six 32-bit inputs.  The asm packs ``{$2, $3}``,
    ``{$4, $5}`` and ``{$6, $7}`` into the 64-bit registers ``ra``, ``rb`` and
    ``rc``, issues ``fma.rn.f32x2 rd, ra, rb, rc`` and unpacks ``rd`` back into
    ``$0``/``$1``.

    Args:
        a: passed as the first inline-asm argument.
        b: passed as the second inline-asm argument.
        c: passed as the third inline-asm argument.

    Returns:
        The result of the inline-asm call, declared as ``tl.float32``.
    """
    # NOTE(review): presumably the input registers follow the order of the
    # argument list, so the result is a * b + c per element -- confirm against
    # the Triton inline_asm_elementwise documentation.
    return tl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc, rd;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mov.b64 rc, { $6, $7 };
            fma.rn.f32x2 rd, ra, rb, rc;
            mov.b64 { $0, $1 }, rd;
        }
        """,
        "=r,=r,r,r,r,r,r,r",
        [a, b, c],
        dtype=tl.float32,
        is_pure=True,  # no side effects declared for the asm
        pack=2,  # two elements per asm instance, matching the f32x2 width
    )
203+
204+
149205
@triton.jit
150206
def tanh_approx_fp32(x):
151207
output = tl.inline_asm_elementwise(
@@ -164,7 +220,16 @@ def tanh_approx_fp32(x):
164220
# typical configuration is 3/fast_gelu
165221
@triton.jit
def fast_gelu(x):
    """Tanh-based GELU-style activation written with the packed f32x2 helpers.

    Evaluates ``x * tanh((c1 * x * x + c0) * x) + x``.  Because
    ``c1 == c0 * 0.044715``, this equals
    ``x * (1 + tanh(c0 * x * (1 + 0.044715 * x * x)))``.

    NOTE(review): unlike the "WAS" expression below, this form applies neither
    the ``0.5`` factor nor the scaling -- presumably both are applied
    elsewhere (see D80750725); confirm against the callers.

    Args:
        x: input values fed to the f32x2 helpers and ``tanh_approx_fp32``.

    Returns:
        The value of the final fused multiply-add, ``x * tanh(inner) + x``.
    """
    # following D80750725
    # WAS: x * 0.5 * (1 + tanh_approx_fp32(0.7978845608 * x * (1.0 + 0.044715 * x * x))) * scaling
    # NOW: x * tanh((c1 * x * x + c0)*x) + x
    c1 = 0.0356774081  # equals c0 * 0.044715
    c0 = 0.7978845608  # approximately sqrt(2 / pi)
    square = _mul_f32x2(x, x)  # x^2
    inner = _fma_f32x2(c1, square, c0)  # c1 * x^2 + c0
    inner = _mul_f32x2(inner, x)  # (c1 * x^2 + c0) * x
    out = _fma_f32x2(x, tanh_approx_fp32(inner), x)  # x * tanh(inner) + x
    return out
168233

169234

170235
@triton.autotune(
@@ -255,7 +320,7 @@ def gdpa_kernel_tma_ws_blackwell(
255320
)
256321

257322
if USE_ON_DEVICE_TMA:
258-
dtype = V.dtype.element_ty # v_dtype)
323+
dtype = V.dtype.element_ty
259324
else:
260325
dtype = tlx.dtype_of(v_desc)
261326

@@ -287,18 +352,14 @@ def gdpa_kernel_tma_ws_blackwell(
287352
consumer_release_q0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q, arrive_count=1)
288353
consumer_release_q1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q, arrive_count=1)
289354
consumer_kv = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)
290-
# consumer_v = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)
291355
consumer_release_kv = tlx.alloc_barriers(
292356
num_barriers=NUM_BUFFERS_KV, arrive_count=1
293357
)
294358
tlx.barrier_arrive(consumer_release_kv[0], 1)
295359
tlx.barrier_arrive(consumer_release_kv[1], 1)
296360
tlx.barrier_arrive(consumer_release_kv[2], 1)
297-
# consumer_release_v = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)
298361

299-
# producer_qk0 == consumer_release_qk0
300362
producer_qk0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_QK, arrive_count=1)
301-
# producer_commit_qk0 == consumer_qk0
302363
producer_commit_qk0 = tlx.alloc_barriers(
303364
num_barriers=NUM_BUFFERS_QK, arrive_count=1
304365
)
@@ -307,13 +368,9 @@ def gdpa_kernel_tma_ws_blackwell(
307368
num_barriers=NUM_BUFFERS_QK, arrive_count=1
308369
)
309370

310-
producer_o0 = tlx.alloc_barriers(
311-
num_barriers=NUM_BUFFERS_O, arrive_count=1
312-
) # only acquire for the first iteration
371+
producer_o0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
313372
producer_commit_o0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
314-
producer_o1 = tlx.alloc_barriers(
315-
num_barriers=NUM_BUFFERS_O, arrive_count=1
316-
) # only acquire for the first iteration
373+
producer_o1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
317374
producer_commit_o1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
318375

319376
with tlx.async_tasks():
@@ -343,27 +400,31 @@ def gdpa_kernel_tma_ws_blackwell(
343400
for start_n in range(lo, hi, BLOCK_N):
344401
start_n = tl.multiple_of(start_n, BLOCK_N)
345402
# tl.device_print("default start_n", start_n)
346-
## communication channel for qk0, p0
347-
# qk in tmem, output p in tmem
348403
bufIdx = accum_cnt % NUM_BUFFERS_QK
349404
phase = (accum_cnt // NUM_BUFFERS_QK) & 1
350405
qk_view = tlx.local_view(qk0_buf, bufIdx)
351406
consumer_qk_view = tlx.local_view(producer_commit_qk0, bufIdx)
352407
# tl.device_print("default producer_commit_qk0", accum_cnt)
353408
# tl.device_print("default producer_commit_qk0_phase", phase)
354409
tlx.barrier_wait(consumer_qk_view, phase)
355-
qk0 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
356-
# ConsumerWait for qk, ProducerAcquire for p
357-
# if activation_enum_int == 3:
410+
411+
# qk_view: BLOCK_M // 2, HEAD_DIM
412+
qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
413+
qk0 = tlx.local_load(qk_view_1st)
358414
p0 = fast_gelu(qk0)
359-
p0 *= qk_scale
360-
if USE_ON_DEVICE_TMA:
361-
p0 = p0.to(V.dtype.element_ty) # v_dtype)
362-
else:
363-
p0 = p0.to(tlx.dtype_of(v_desc))
364-
qk_view = tlx.local_view(qk0_buf, bufIdx)
365-
p0_view = tlx.local_reinterpret(qk_view, dtype)
366-
tlx.local_store(p0_view, p0) # , tlx.storage_kind.tmem)
415+
p0 = p0.to(dtype)
416+
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
417+
tlx.local_store(p0_view, p0)
418+
419+
qk_view_2nd = tlx.subslice(
420+
qk_view, HEAD_DIM // 2, HEAD_DIM // 2
421+
)
422+
qk0 = tlx.local_load(qk_view_2nd)
423+
p0 = fast_gelu(qk0)
424+
p0 = p0.to(dtype)
425+
p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
426+
tlx.local_store(p0_view, p0)
427+
367428
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
368429
consumer_release_qk_view = tlx.local_view(producer_qk0, bufIdx)
369430
tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -382,10 +443,8 @@ def gdpa_kernel_tma_ws_blackwell(
382443
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
383444
accum_cnt_outer, NUM_BUFFERS_O
384445
)
385-
o0_view = tlx.local_view(
386-
o0_buf, bufIdx_o_outer
387-
) # FIXME: index for the last iteration
388-
o0 = tlx.local_load(o0_view) # , tlx.storage_kind.tmem)
446+
o0_view = tlx.local_view(o0_buf, bufIdx_o_outer)
447+
o0 = tlx.local_load(o0_view)
389448
# release o0 here
390449
consumer_release_o0_view = tlx.local_view(
391450
producer_o0, bufIdx_o_outer
@@ -437,24 +496,29 @@ def gdpa_kernel_tma_ws_blackwell(
437496
for start_n in range(lo, hi, BLOCK_N):
438497
start_n = tl.multiple_of(start_n, BLOCK_N)
439498
## communication channel for qk1, p1
440-
# qk in tmem, output p in tmem
441499
bufIdx = accum_cnt % NUM_BUFFERS_QK
442500
phase = (accum_cnt // NUM_BUFFERS_QK) & 1
443501
qk_view = tlx.local_view(qk1_buf, bufIdx)
444502
consumer_qk_view = tlx.local_view(producer_commit_qk1, bufIdx)
445503
tlx.barrier_wait(consumer_qk_view, phase)
446-
qk1 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
447-
# ConsumerWait for qk, ProducerAcquire for p
448-
# if activation_enum_int == 3:
449-
p1 = fast_gelu(qk1)
450-
p1 *= qk_scale
451-
if USE_ON_DEVICE_TMA:
452-
p1 = p1.to(V.dtype.element_ty) # v_dtype)
453-
else:
454-
p1 = p1.to(tlx.dtype_of(v_desc))
455-
qk_view = tlx.local_view(qk1_buf, bufIdx)
456-
p1_view = tlx.local_reinterpret(qk_view, dtype)
457-
tlx.local_store(p1_view, p1) # , tlx.storage_kind.tmem)
504+
505+
# qk_view: BLOCK_M // 2, HEAD_DIM
506+
qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
507+
qk0 = tlx.local_load(qk_view_1st)
508+
p0 = fast_gelu(qk0)
509+
p0 = p0.to(dtype)
510+
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
511+
tlx.local_store(p0_view, p0)
512+
513+
qk_view_2nd = tlx.subslice(
514+
qk_view, HEAD_DIM // 2, HEAD_DIM // 2
515+
)
516+
qk0 = tlx.local_load(qk_view_2nd)
517+
p0 = fast_gelu(qk0)
518+
p0 = p0.to(dtype)
519+
p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
520+
tlx.local_store(p0_view, p0)
521+
458522
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
459523
consumer_release_qk_view = tlx.local_view(producer_qk1, bufIdx)
460524
tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -477,10 +541,8 @@ def gdpa_kernel_tma_ws_blackwell(
477541
strides=[HEAD_DIM * H, 1],
478542
block_shape=[BLOCK_M // 2, BLOCK_D],
479543
)
480-
o1_view = tlx.local_view(
481-
o1_buf, bufIdx_o_outer
482-
) # FIXME: should be 0
483-
o1 = tlx.local_load(o1_view) # , tlx.storage_kind.tmem)
544+
o1_view = tlx.local_view(o1_buf, bufIdx_o_outer)
545+
o1 = tlx.local_load(o1_view)
484546
# release o1 here
485547
consumer_release_o1_view = tlx.local_view(
486548
producer_o1, bufIdx_o_outer
@@ -620,7 +682,6 @@ def gdpa_kernel_tma_ws_blackwell(
620682
consumer_p0_view, phase_p
621683
) # consumer wait for p0 due to reuse of p0 and qk0
622684
# reinterpret qk0 as p0
623-
# p0_view = _reinterpret(qk0_buf, bufIdx_p)
624685
qk_view = tlx.local_view(qk0_buf, bufIdx_p)
625686
p0_view = tlx.local_reinterpret(qk_view, dtype)
626687

@@ -712,7 +773,6 @@ def gdpa_kernel_tma_ws_blackwell(
712773
consumer_release_kv, bufIdx_v
713774
)
714775
# reinterpret as p1
715-
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
716776
qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
717777
p1_view = tlx.local_reinterpret(qk_view, dtype)
718778
tlx.async_dot( # p1 . v from previous iteration
@@ -773,7 +833,6 @@ def gdpa_kernel_tma_ws_blackwell(
773833
consumer_p0_view, phase_qk
774834
) # consumer wait for p0 use producer_qk0 due to reuse
775835
# reinterpret as p0
776-
# p0_view = _reinterpret(qk0_buf, bufIdx_qk)
777836
qk_view = tlx.local_view(qk0_buf, bufIdx_qk)
778837
p0_view = tlx.local_reinterpret(qk_view, dtype)
779838

@@ -822,7 +881,6 @@ def gdpa_kernel_tma_ws_blackwell(
822881
tlx.barrier_wait(
823882
consumer_p1_view, phase_qk1
824883
) # consumer wait for p1 due to reuse of p1 and qk1
825-
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
826884
qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
827885
p1_view = tlx.local_reinterpret(qk_view, dtype)
828886

0 commit comments

Comments
 (0)