@@ -415,10 +415,11 @@ def _do_activation(
415
415
tlx .barrier_wait (consumer_qk_view , phase )
416
416
qk = tlx .local_load (qk_view , tlx .storage_kind .tmem )
417
417
# ConsumerWait for qk, ProducerAcquire for p
418
- if activation_enum_int == 3 :
419
- p = fast_gelu (qk )
420
- else :
421
- p = qk
418
+ # hardcode to fast_gelu
419
+ # if activation_enum_int == 3:
420
+ p = fast_gelu (qk )
421
+ # else:
422
+ # p = qk
422
423
423
424
p *= qk_scale
424
425
p = p .to (v_dtype )
@@ -604,12 +605,24 @@ def gdpa_kernel_tma_ws_blackwell(
604
605
tlx .barrier_wait (consumer_qk_view , phase )
605
606
qk0 = tlx .local_load (qk_view , tlx .storage_kind .tmem )
606
607
# ConsumerWait for qk, ProducerAcquire for p
607
- if activation_enum_int == 3 :
608
- p0 = fast_gelu (qk0 )
609
- else :
610
- p0 = qk0
608
+ # if activation_enum_int == 3:
609
+ p0 = (
610
+ qk0
611
+ * 0.5
612
+ * (
613
+ 1
614
+ + tanh_approx_fp32 (
615
+ 0.7978845608 * qk0 * (1.0 + 0.044715 * qk0 * qk0 )
616
+ )
617
+ )
618
+ ) # fast_gelu(qk0)
619
+ # else:
620
+ # p0 = qk0
611
621
p0 *= qk_scale
612
622
p0 = p0 .to (V .dtype .element_ty ) # v_dtype)
623
+ qk_view = tlx .local_view (qk0_buf , bufIdx )
624
+ p0_view = tlx .local_reinterpret (qk_view , tl .float16 )
625
+ tlx .local_store (p0_view , p0 , tlx .storage_kind .tmem )
613
626
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
614
627
consumer_release_qk_view = tlx .local_view (producer_qk0 , bufIdx )
615
628
tlx .barrier_arrive (consumer_release_qk_view , 1 )
@@ -692,12 +705,24 @@ def gdpa_kernel_tma_ws_blackwell(
692
705
tlx .barrier_wait (consumer_qk_view , phase )
693
706
qk1 = tlx .local_load (qk_view , tlx .storage_kind .tmem )
694
707
# ConsumerWait for qk, ProducerAcquire for p
695
- if activation_enum_int == 3 :
696
- p1 = fast_gelu (qk1 )
697
- else :
698
- p1 = qk1
708
+ # if activation_enum_int == 3:
709
+ p1 = (
710
+ qk1
711
+ * 0.5
712
+ * (
713
+ 1
714
+ + tanh_approx_fp32 (
715
+ 0.7978845608 * qk1 * (1.0 + 0.044715 * qk1 * qk1 )
716
+ )
717
+ )
718
+ ) # fast_gelu(qk1)
719
+ # else:
720
+ # p1 = qk1
699
721
p1 *= qk_scale
700
722
p1 = p1 .to (V .dtype .element_ty ) # v_dtype)
723
+ qk_view = tlx .local_view (qk1_buf , bufIdx )
724
+ p1_view = tlx .local_reinterpret (qk_view , tl .float16 )
725
+ tlx .local_store (p1_view , p1 , tlx .storage_kind .tmem )
701
726
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
702
727
consumer_release_qk_view = tlx .local_view (producer_qk1 , bufIdx )
703
728
tlx .barrier_arrive (consumer_release_qk_view , 1 )
@@ -1233,6 +1258,7 @@ def gdpa_kernel_tma_ws_blackwell(
1233
1258
],
1234
1259
v_full_view ,
1235
1260
)
1261
+ accum_count_k += 1
1236
1262
1237
1263
accum_count_q += 1
1238
1264
@@ -1355,6 +1381,7 @@ def grid_tma_persistent(META):
1355
1381
vstrides = v .stride ()
1356
1382
1357
1383
activation_enum_int = activation_string_to_int (activation )
1384
+ print ("activation_enum_int" , activation , activation_enum_int )
1358
1385
1359
1386
gdpa_kernel_tma_ws_blackwell [grid_tma_persistent ](
1360
1387
q ,
0 commit comments