Skip to content

Commit 1890a9e

Browse files
committed
inline fast_gelu also add tmem_store
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 458a903 commit 1890a9e

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

tritonbench/operators/gdpa/gdpa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,7 @@ def grid_tma_persistent(META):
21972197
ad_to_request_offset = create_dummy_tensor(query)
21982198

21992199
activation_enum_int = activation_string_to_int(activation)
2200+
print("activation_enum_int", activation, activation_enum_int)
22002201
kernel_info = capture_triton(kernel_fn)[grid](
22012202
q,
22022203
query_offset,

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,11 @@ def _do_activation(
415415
tlx.barrier_wait(consumer_qk_view, phase)
416416
qk = tlx.local_load(qk_view, tlx.storage_kind.tmem)
417417
# ConsumerWait for qk, ProducerAcquire for p
418-
if activation_enum_int == 3:
419-
p = fast_gelu(qk)
420-
else:
421-
p = qk
418+
# hardcode to fast_gelu
419+
# if activation_enum_int == 3:
420+
p = fast_gelu(qk)
421+
# else:
422+
# p = qk
422423

423424
p *= qk_scale
424425
p = p.to(v_dtype)
@@ -604,12 +605,24 @@ def gdpa_kernel_tma_ws_blackwell(
604605
tlx.barrier_wait(consumer_qk_view, phase)
605606
qk0 = tlx.local_load(qk_view, tlx.storage_kind.tmem)
606607
# ConsumerWait for qk, ProducerAcquire for p
607-
if activation_enum_int == 3:
608-
p0 = fast_gelu(qk0)
609-
else:
610-
p0 = qk0
608+
# if activation_enum_int == 3:
609+
p0 = (
610+
qk0
611+
* 0.5
612+
* (
613+
1
614+
+ tanh_approx_fp32(
615+
0.7978845608 * qk0 * (1.0 + 0.044715 * qk0 * qk0)
616+
)
617+
)
618+
) # fast_gelu(qk0)
619+
# else:
620+
# p0 = qk0
611621
p0 *= qk_scale
612622
p0 = p0.to(V.dtype.element_ty) # v_dtype)
623+
qk_view = tlx.local_view(qk0_buf, bufIdx)
624+
p0_view = tlx.local_reinterpret(qk_view, tl.float16)
625+
tlx.local_store(p0_view, p0, tlx.storage_kind.tmem)
613626
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
614627
consumer_release_qk_view = tlx.local_view(producer_qk0, bufIdx)
615628
tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -692,12 +705,24 @@ def gdpa_kernel_tma_ws_blackwell(
692705
tlx.barrier_wait(consumer_qk_view, phase)
693706
qk1 = tlx.local_load(qk_view, tlx.storage_kind.tmem)
694707
# ConsumerWait for qk, ProducerAcquire for p
695-
if activation_enum_int == 3:
696-
p1 = fast_gelu(qk1)
697-
else:
698-
p1 = qk1
708+
# if activation_enum_int == 3:
709+
p1 = (
710+
qk1
711+
* 0.5
712+
* (
713+
1
714+
+ tanh_approx_fp32(
715+
0.7978845608 * qk1 * (1.0 + 0.044715 * qk1 * qk1)
716+
)
717+
)
718+
) # fast_gelu(qk1)
719+
# else:
720+
# p1 = qk1
699721
p1 *= qk_scale
700722
p1 = p1.to(V.dtype.element_ty) # v_dtype)
723+
qk_view = tlx.local_view(qk1_buf, bufIdx)
724+
p1_view = tlx.local_reinterpret(qk_view, tl.float16)
725+
tlx.local_store(p1_view, p1, tlx.storage_kind.tmem)
701726
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
702727
consumer_release_qk_view = tlx.local_view(producer_qk1, bufIdx)
703728
tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -1233,6 +1258,7 @@ def gdpa_kernel_tma_ws_blackwell(
12331258
],
12341259
v_full_view,
12351260
)
1261+
accum_count_k += 1
12361262

12371263
accum_count_q += 1
12381264

@@ -1355,6 +1381,7 @@ def grid_tma_persistent(META):
13551381
vstrides = v.stride()
13561382

13571383
activation_enum_int = activation_string_to_int(activation)
1384+
print("activation_enum_int", activation, activation_enum_int)
13581385

13591386
gdpa_kernel_tma_ws_blackwell[grid_tma_persistent](
13601387
q,

0 commit comments

Comments
 (0)