@@ -104,13 +104,6 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS):
    return bufIdx, phase


-@triton.jit
-def _reinterpret(qk_buf, bufIdx_qk):
-    qk_view = tlx.local_view(qk_buf, bufIdx_qk)
-    p_view = tlx.local_reinterpret(qk_view, tl.float16)
-    return p_view
-
-
@triton.jit
def _load_tma(
    bufIdx, phase, empty_bars, full_bars, buffers, desc, offset_1, offset_0, num_bytes
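Note: the circular-buffer helper referenced by this hunk, _get_bufidx_phase, is not shown in full here. A minimal sketch consistent with the inline pattern used later in this diff (bufIdx = accum_cnt % NUM_BUFFERS; phase = (accum_cnt // NUM_BUFFERS) & 1) would look like the following; the name and constexpr annotation are illustrative, not the upstream definition.

import triton
import triton.language as tl

@triton.jit
def _get_bufidx_phase_sketch(accum_cnt, NUM_BUFFERS: tl.constexpr):
    # Map a monotonically increasing counter to a buffer slot plus a 0/1 parity
    # bit, i.e. the phase value passed to tlx.barrier_wait elsewhere in this kernel.
    bufIdx = accum_cnt % NUM_BUFFERS
    phase = (accum_cnt // NUM_BUFFERS) & 1
    return bufIdx, phase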
@@ -146,6 +139,69 @@ def _load_tma(
# qk0, qk1: producers
# p0, p1: sharing tmem spaces, and barriers with qk0, qk1 (consumers)
# o0, o1
+
+
+@triton.jit
+def _add_f32x2(a, b):
+    return tl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            add.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
+@triton.jit
+def _mul_f32x2(a, b):
+    return tl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mul.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
+@triton.jit
+def _fma_f32x2(a, b, c):
+    return tl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc, rd;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mov.b64 rc, { $6, $7 };
+            fma.rn.f32x2 rd, ra, rb, rc;
+            mov.b64 { $0, $1 }, rd;
+        }
+        """,
+        "=r,=r,r,r,r,r,r,r",
+        [a, b, c],
+        dtype=tl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
@triton.jit
def tanh_approx_fp32(x):
    output = tl.inline_asm_elementwise(
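These helpers pack two fp32 lanes into a 64-bit register pair so a single packed f32x2 PTX instruction handles both: with pack=2, each asm invocation consumes two elements per input operand ($2..$5) and writes two outputs ($0, $1), matching the "=r,=r,r,r,r,r" constraint string. A minimal standalone usage sketch, not part of this commit; it assumes a Blackwell-class GPU (add.f32x2 is a recent sm_100-era PTX instruction) and that _add_f32x2 from this diff is in scope:

import torch
import triton
import triton.language as tl

@triton.jit
def _packed_add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Plain elementwise add, routed through the packed f32x2 helper.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, _add_f32x2(x, y), mask=mask)

x = torch.randn(4096, device="cuda", dtype=torch.float32)
y = torch.randn_like(x)
out = torch.empty_like(x)
_packed_add_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, out, x.numel(), BLOCK=1024)
torch.testing.assert_close(out, x + y)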
@@ -164,7 +220,16 @@ def tanh_approx_fp32(x):
# typical configuration is 3/fast_gelu
@triton.jit
def fast_gelu(x):
-    return x * 0.5 * (1 + tanh_approx_fp32(0.7978845608 * x * (1.0 + 0.044715 * x * x)))
+    # following D80750725
+    # WAS: x * 0.5 * (1 + tanh_approx_fp32(0.7978845608 * x * (1.0 + 0.044715 * x * x))) * scaling
+    # NOW: x * tanh((c1 * x * x + c0)*x) + x
+    c1 = 0.0356774081
+    c0 = 0.7978845608
+    square = _mul_f32x2(x, x)
+    inner = _fma_f32x2(c1, square, c0)
+    inner = _mul_f32x2(inner, x)
+    out = _fma_f32x2(x, tanh_approx_fp32(inner), x)
+    return out


@triton.autotune(
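The rewrite folds the constant 0.7978845608 * 0.044715 ≈ 0.0356774081 into c1, so the tanh argument is unchanged, and x * tanh(u) + x equals x * 0.5 * (1 + tanh(u)) * 2, i.e. the WAS expression with scaling = 2. A quick NumPy sanity check of that algebra (a standalone sketch, not part of the commit, with np.tanh standing in for tanh_approx_fp32):

import numpy as np

x = np.linspace(-6.0, 6.0, 1001)
u = 0.7978845608 * x * (1.0 + 0.044715 * x * x)
was = x * 0.5 * (1.0 + np.tanh(u)) * 2.0  # old form with scaling = 2
now = x * np.tanh((0.0356774081 * x * x + 0.7978845608) * x) + x
print(np.abs(was - now).max())  # tiny; differs only by the rounding of c1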
@@ -255,7 +320,7 @@ def gdpa_kernel_tma_ws_blackwell(
    )

    if USE_ON_DEVICE_TMA:
-        dtype = V.dtype.element_ty  # v_dtype)
+        dtype = V.dtype.element_ty
    else:
        dtype = tlx.dtype_of(v_desc)

@@ -287,18 +352,14 @@ def gdpa_kernel_tma_ws_blackwell(
    consumer_release_q0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q, arrive_count=1)
    consumer_release_q1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_Q, arrive_count=1)
    consumer_kv = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)
-    # consumer_v = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)
    consumer_release_kv = tlx.alloc_barriers(
        num_barriers=NUM_BUFFERS_KV, arrive_count=1
    )
    tlx.barrier_arrive(consumer_release_kv[0], 1)
    tlx.barrier_arrive(consumer_release_kv[1], 1)
    tlx.barrier_arrive(consumer_release_kv[2], 1)
-    # consumer_release_v = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_KV, arrive_count=1)

-    # producer_qk0 == consumer_release_qk0
    producer_qk0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_QK, arrive_count=1)
-    # producer_commit_qk0 == consumer_qk0
    producer_commit_qk0 = tlx.alloc_barriers(
        num_barriers=NUM_BUFFERS_QK, arrive_count=1
    )
@@ -307,13 +368,9 @@ def gdpa_kernel_tma_ws_blackwell(
        num_barriers=NUM_BUFFERS_QK, arrive_count=1
    )

-    producer_o0 = tlx.alloc_barriers(
-        num_barriers=NUM_BUFFERS_O, arrive_count=1
-    )  # only acquire for the first iteration
+    producer_o0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
    producer_commit_o0 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
-    producer_o1 = tlx.alloc_barriers(
-        num_barriers=NUM_BUFFERS_O, arrive_count=1
-    )  # only acquire for the first iteration
+    producer_o1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)
    producer_commit_o1 = tlx.alloc_barriers(num_barriers=NUM_BUFFERS_O, arrive_count=1)

    with tlx.async_tasks():
@@ -343,27 +400,31 @@ def gdpa_kernel_tma_ws_blackwell(
            for start_n in range(lo, hi, BLOCK_N):
                start_n = tl.multiple_of(start_n, BLOCK_N)
                # tl.device_print("default start_n", start_n)
-                ## communication channel for qk0, p0
-                # qk in tmem, output p in tmem
                bufIdx = accum_cnt % NUM_BUFFERS_QK
                phase = (accum_cnt // NUM_BUFFERS_QK) & 1
                qk_view = tlx.local_view(qk0_buf, bufIdx)
                consumer_qk_view = tlx.local_view(producer_commit_qk0, bufIdx)
                # tl.device_print("default producer_commit_qk0", accum_cnt)
                # tl.device_print("default producer_commit_qk0_phase", phase)
                tlx.barrier_wait(consumer_qk_view, phase)
-                qk0 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
-                # ConsumerWait for qk, ProducerAcquire for p
-                # if activation_enum_int == 3:
+
+                # qk_view: BLOCK_M // 2, HEAD_DIM
+                qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
+                qk0 = tlx.local_load(qk_view_1st)
                p0 = fast_gelu(qk0)
-                p0 *= qk_scale
-                if USE_ON_DEVICE_TMA:
-                    p0 = p0.to(V.dtype.element_ty)  # v_dtype)
-                else:
-                    p0 = p0.to(tlx.dtype_of(v_desc))
-                qk_view = tlx.local_view(qk0_buf, bufIdx)
-                p0_view = tlx.local_reinterpret(qk_view, dtype)
-                tlx.local_store(p0_view, p0)  # , tlx.storage_kind.tmem)
+                p0 = p0.to(dtype)
+                p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
+                tlx.local_store(p0_view, p0)
+
+                qk_view_2nd = tlx.subslice(
+                    qk_view, HEAD_DIM // 2, HEAD_DIM // 2
+                )
+                qk0 = tlx.local_load(qk_view_2nd)
+                p0 = fast_gelu(qk0)
+                p0 = p0.to(dtype)
+                p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
+                tlx.local_store(p0_view, p0)
+
                # p and qk reuse tmem space, single producer commit for p via consumer_release_qk
                consumer_release_qk_view = tlx.local_view(producer_qk0, bufIdx)
                tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -382,10 +443,8 @@ def gdpa_kernel_tma_ws_blackwell(
            bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
                accum_cnt_outer, NUM_BUFFERS_O
            )
-            o0_view = tlx.local_view(
-                o0_buf, bufIdx_o_outer
-            )  # FIXME: index for the last iteration
-            o0 = tlx.local_load(o0_view)  # , tlx.storage_kind.tmem)
+            o0_view = tlx.local_view(o0_buf, bufIdx_o_outer)
+            o0 = tlx.local_load(o0_view)
            # release o0 here
            consumer_release_o0_view = tlx.local_view(
                producer_o0, bufIdx_o_outer
@@ -437,24 +496,29 @@ def gdpa_kernel_tma_ws_blackwell(
            for start_n in range(lo, hi, BLOCK_N):
                start_n = tl.multiple_of(start_n, BLOCK_N)
                ## communication channel for qk1, p1
-                # qk in tmem, output p in tmem
                bufIdx = accum_cnt % NUM_BUFFERS_QK
                phase = (accum_cnt // NUM_BUFFERS_QK) & 1
                qk_view = tlx.local_view(qk1_buf, bufIdx)
                consumer_qk_view = tlx.local_view(producer_commit_qk1, bufIdx)
                tlx.barrier_wait(consumer_qk_view, phase)
-                qk1 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
-                # ConsumerWait for qk, ProducerAcquire for p
-                # if activation_enum_int == 3:
-                p1 = fast_gelu(qk1)
-                p1 *= qk_scale
-                if USE_ON_DEVICE_TMA:
-                    p1 = p1.to(V.dtype.element_ty)  # v_dtype)
-                else:
-                    p1 = p1.to(tlx.dtype_of(v_desc))
-                qk_view = tlx.local_view(qk1_buf, bufIdx)
-                p1_view = tlx.local_reinterpret(qk_view, dtype)
-                tlx.local_store(p1_view, p1)  # , tlx.storage_kind.tmem)
+
+                # qk_view: BLOCK_M // 2, HEAD_DIM
+                qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
+                qk0 = tlx.local_load(qk_view_1st)
+                p0 = fast_gelu(qk0)
+                p0 = p0.to(dtype)
+                p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
+                tlx.local_store(p0_view, p0)
+
+                qk_view_2nd = tlx.subslice(
+                    qk_view, HEAD_DIM // 2, HEAD_DIM // 2
+                )
+                qk0 = tlx.local_load(qk_view_2nd)
+                p0 = fast_gelu(qk0)
+                p0 = p0.to(dtype)
+                p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
+                tlx.local_store(p0_view, p0)
+
                # p and qk reuse tmem space, single producer commit for p via consumer_release_qk
                consumer_release_qk_view = tlx.local_view(producer_qk1, bufIdx)
                tlx.barrier_arrive(consumer_release_qk_view, 1)
@@ -477,10 +541,8 @@ def gdpa_kernel_tma_ws_blackwell(
                    strides=[HEAD_DIM * H, 1],
                    block_shape=[BLOCK_M // 2, BLOCK_D],
                )
-                o1_view = tlx.local_view(
-                    o1_buf, bufIdx_o_outer
-                )  # FIXME: should be 0
-                o1 = tlx.local_load(o1_view)  # , tlx.storage_kind.tmem)
+                o1_view = tlx.local_view(o1_buf, bufIdx_o_outer)
+                o1 = tlx.local_load(o1_view)
                # release o1 here
                consumer_release_o1_view = tlx.local_view(
                    producer_o1, bufIdx_o_outer
@@ -620,7 +682,6 @@ def gdpa_kernel_tma_ws_blackwell(
                    consumer_p0_view, phase_p
                )  # consumer wait for p0 due to reuse of p0 and qk0
                # reinterpret qk0 as p0
-                # p0_view = _reinterpret(qk0_buf, bufIdx_p)
                qk_view = tlx.local_view(qk0_buf, bufIdx_p)
                p0_view = tlx.local_reinterpret(qk_view, dtype)
@@ -712,7 +773,6 @@ def gdpa_kernel_tma_ws_blackwell(
                    consumer_release_kv, bufIdx_v
                )
                # reinterpret as p1
-                # p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
                qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
                p1_view = tlx.local_reinterpret(qk_view, dtype)
                tlx.async_dot(  # p1 . v from previous iteration
@@ -773,7 +833,6 @@ def gdpa_kernel_tma_ws_blackwell(
                    consumer_p0_view, phase_qk
                )  # consumer wait for p0 use producer_qk0 due to reuse
                # reinterpret as p0
-                # p0_view = _reinterpret(qk0_buf, bufIdx_qk)
                qk_view = tlx.local_view(qk0_buf, bufIdx_qk)
                p0_view = tlx.local_reinterpret(qk_view, dtype)
@@ -822,7 +881,6 @@ def gdpa_kernel_tma_ws_blackwell(
                tlx.barrier_wait(
                    consumer_p1_view, phase_qk1
                )  # consumer wait for p1 due to reuse of p1 and qk1
-                # p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
                qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
                p1_view = tlx.local_reinterpret(qk_view, dtype)