8
8
from triton .experimental .gluon import language as gl
9
9
from triton .experimental .gluon .language .nvidia .blackwell import (
10
10
allocate_tensor_memory ,
11
+ float2 ,
11
12
get_tmem_32x32b_reg_layout ,
12
13
mbarrier ,
13
14
tcgen05_commit ,
@@ -69,6 +70,7 @@ def increment(self):
69
70
70
71
71
72
def Channel (T , alloc_fn ):
73
+
72
74
@aggregate
73
75
class ChannelType :
74
76
mem : T
@@ -243,9 +245,7 @@ class AttentionConfig:
243
245
alpha_2d_layout : gl .constexpr
244
246
245
247
num_kv_buffers : gl .constexpr
246
- use_fadd2_reduce : gl .constexpr
247
248
use_exp2_turnstile : gl .constexpr
248
- use_ffma2_scale_rowmax : gl .constexpr
249
249
250
250
def __init__ (
251
251
self ,
@@ -290,13 +290,13 @@ def __init__(
290
290
qk_instr_shape = get_mma_instr_shape (self .qk_shape , gl .float32 )
291
291
o_instr_shape = get_mma_instr_shape (self .o_shape , gl .float32 )
292
292
self .qk_tmem_layout = gl .constexpr (
293
- TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), unpacked = True )
293
+ TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), col_stride = 1 )
294
294
)
295
295
self .o_tmem_layout = gl .constexpr (
296
- TensorMemoryLayout ((o_instr_shape [0 ], o_instr_shape [1 ]), unpacked = True )
296
+ TensorMemoryLayout ((o_instr_shape [0 ], o_instr_shape [1 ]), col_stride = 1 )
297
297
)
298
298
self .p_tmem_layout = gl .constexpr (
299
- TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), unpacked = False )
299
+ TensorMemoryLayout ((qk_instr_shape [0 ], qk_instr_shape [1 ]), col_stride = 1 )
300
300
)
301
301
302
302
self .qk_layout = gl .constexpr (
@@ -321,17 +321,13 @@ def __init__(
321
321
gl .BlockedLayout ([1 , 1 ], [32 , 1 ], [self .num_warps , 1 ], [0 , 1 ])
322
322
)
323
323
324
- is_fp16 = dtype .value in [gl .float16 , gl .bfloat16 ]
324
+ is_fp16 = self . dtype .value in [gl .float16 , gl .bfloat16 ]
325
325
if is_fp16 :
326
326
self .num_kv_buffers = gl .constexpr (3 if HEAD_DIM == 128 else 6 )
327
327
else :
328
328
self .num_kv_buffers = gl .constexpr (4 if HEAD_DIM == 128 else 8 )
329
329
330
- self .use_fadd2_reduce = gl .constexpr (HEAD_DIM == 64 )
331
330
self .use_exp2_turnstile = gl .constexpr (HEAD_DIM == 64 )
332
- self .use_ffma2_scale_rowmax = gl .constexpr (
333
- HEAD_DIM == 128 or is_fp16 == (STAGE == 3 )
334
- )
335
331
336
332
@gluon .jit
337
333
def get_program (self , pid_m , pid_n ):
@@ -421,113 +417,6 @@ def get_loop_bounds(self, STAGE: gl.constexpr):
421
417
return lo , hi
422
418
423
419
424
- # ===-----------------------------------------------------------------------===#
425
- # float2
426
- # ===-----------------------------------------------------------------------===#
427
-
428
-
429
- @gluon .jit
430
- def _add_f32x2 (a , b ):
431
- return gl .inline_asm_elementwise (
432
- """
433
- {
434
- .reg .b64 ra, rb, rc;
435
- mov.b64 ra, { $2, $3 };
436
- mov.b64 rb, { $4, $5 };
437
- add.f32x2 rc, ra, rb;
438
- mov.b64 { $0, $1 }, rc;
439
- }
440
- """ ,
441
- "=r,=r,r,r,r,r" ,
442
- [a , b ],
443
- dtype = gl .float32 ,
444
- is_pure = True ,
445
- pack = 2 ,
446
- )
447
-
448
-
449
- @gluon .jit
450
- def _mul_f32x2 (a , b ):
451
- return gl .inline_asm_elementwise (
452
- """
453
- {
454
- .reg .b64 ra, rb, rc;
455
- mov.b64 ra, { $2, $3 };
456
- mov.b64 rb, { $4, $5 };
457
- mul.f32x2 rc, ra, rb;
458
- mov.b64 { $0, $1 }, rc;
459
- }
460
- """ ,
461
- "=r,=r,r,r,r,r" ,
462
- [a , b ],
463
- dtype = gl .float32 ,
464
- is_pure = True ,
465
- pack = 2 ,
466
- )
467
-
468
-
469
- @gluon .jit
470
- def _fma_f32x2 (a , b , c ):
471
- return gl .inline_asm_elementwise (
472
- """
473
- {
474
- .reg .b64 ra, rb, rc, rd;
475
- mov.b64 ra, { $2, $3 };
476
- mov.b64 rb, { $4, $5 };
477
- mov.b64 rc, { $6, $7 };
478
- fma.rn.f32x2 rd, ra, rb, rc;
479
- mov.b64 { $0, $1 }, rd;
480
- }
481
- """ ,
482
- "=r,=r,r,r,r,r,r,r" ,
483
- [a , b , c ],
484
- dtype = gl .float32 ,
485
- is_pure = True ,
486
- pack = 2 ,
487
- )
488
-
489
-
490
- @gluon .jit
491
- def _reduce_fadd2 (p0a , p1a , p0b , p1b ):
492
- return gl .inline_asm_elementwise (
493
- """
494
- {
495
- .reg .b64 rc, ra, rb;
496
- mov.b64 ra, { $2, $4 };
497
- mov.b64 rb, { $3, $5 };
498
- add.f32x2 rc, ra, rb;
499
- mov.b64 { $0, $1 }, rc;
500
- }
501
- """ ,
502
- "=r,=r,r,r,r,r" ,
503
- [p0a , p0b , p1a , p1b ],
504
- dtype = [gl .float32 , gl .float32 ],
505
- is_pure = True ,
506
- pack = 1 ,
507
- )
508
-
509
-
510
- @gluon .jit
511
- def _pairwise_fma_f32x2 (a0 , b0 , c0 , a1 , b1 , c1 ):
512
- return gl .inline_asm_elementwise (
513
- """
514
- {
515
- .reg .b64 rd, ra, rb, rc;
516
- mov.b64 ra, { $2, $5 };
517
- mov.b64 rb, { $3, $6 };
518
- mov.b64 rc, { $4, $7 };
519
- fma.rn.f32x2 rd, ra, rb, rc;
520
- mov.b64 { $0, $1 }, rd;
521
- }
522
- """ ,
523
- "=r,=r,r,r,r,r,r,r" ,
524
- [a0 , b0 , c0 , a1 , b1 , c1 ],
525
- dtype = [gl .float32 , gl .float32 ],
526
- is_pure = True ,
527
- pack = 1 ,
528
- )
529
-
530
-
531
420
# ===-----------------------------------------------------------------------===#
532
421
# _gluon_attn
533
422
# ===-----------------------------------------------------------------------===#
@@ -542,15 +431,15 @@ def _borrow_s_as_p(config, s_tmem):
542
431
@gluon .jit
543
432
def _borrow_s_as_alpha (config , s_tmem ):
544
433
alpha_tmem = s_tmem .slice (config .BLOCK_N // 2 , 1 )
545
- alpha_layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], unpacked = True )
434
+ alpha_layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], col_stride = 1 )
546
435
return alpha_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], alpha_layout )
547
436
548
437
549
438
@gluon .jit
550
439
def _borrow_s_for_epilogue (config , s_tmem ):
551
440
m_i_tmem = s_tmem .slice (config .BLOCK_N // 2 + 1 , 1 )
552
441
l_i_tmem = s_tmem .slice (config .BLOCK_N // 2 + 2 , 1 )
553
- layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], unpacked = True )
442
+ layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], col_stride = 1 )
554
443
m_i_tmem = m_i_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], layout )
555
444
l_i_tmem = l_i_tmem ._reinterpret (gl .float32 , [config .SPLIT_M , 1 ], layout )
556
445
return m_i_tmem , l_i_tmem
@@ -798,8 +687,7 @@ def _softmax_inner_loop(
798
687
corr_bar , #
799
688
offs_m ,
800
689
m_i ,
801
- l_i0 ,
802
- l_i1 ,
690
+ l_i ,
803
691
STAGE : gl .constexpr ,
804
692
):
805
693
lo , hi = prog .get_loop_bounds (STAGE )
@@ -821,11 +709,10 @@ def _softmax_inner_loop(
821
709
)
822
710
mbarrier .arrive (corr_bar , count = 1 )
823
711
824
- if config .use_ffma2_scale_rowmax :
825
- qk = _fma_f32x2 (qk , gl .full_like (qk , config .qk_scale ), - m_ij [:, None ])
826
- else :
827
- qk = _mul_f32x2 (qk , gl .full_like (qk , config .qk_scale ))
828
- qk = _add_f32x2 (qk , - m_ij [:, None ])
712
+ rowmax = float2 .pack (- m_ij [:, None ].broadcast_to (qk .shape ), axis = 1 )
713
+ qk = float2 .pack (qk , axis = 1 )
714
+ qk = float2 .fma (qk , float2 .full_like (qk , config .qk_scale ), rowmax )
715
+ qk = float2 .unpack (qk , axis = 1 )
829
716
830
717
# Force the softmax partitions to take turns in the EX2 section. This
831
718
# prevents contention for the EX2 unit and improves utilization.
@@ -844,24 +731,12 @@ def _softmax_inner_loop(
844
731
if config .use_exp2_turnstile :
845
732
mbarrier .arrive (exp_bar , count = 1 )
846
733
847
- if config .use_fadd2_reduce :
848
- p0 , p1 = _split_n (p )
849
- l_ij0 , l_ij1 = gl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
850
- # This is a difference of 1 SASS instruction but it dramatically
851
- # affects instruction scheduling.
852
- alpha = gl .convert_layout (alpha , l_i0 .type .layout , assert_trivial = True )
853
- if config .dtype == gl .float8e5 :
854
- l_i0 , l_i1 = _pairwise_fma_f32x2 (l_i0 , alpha , l_ij0 , l_i1 , alpha , l_ij1 )
855
- else :
856
- l_i0 = l_i0 * alpha + l_ij0
857
- l_i1 = l_i1 * alpha + l_ij1
858
- else :
859
- l_ij = gl .sum (p , axis = 1 )
860
- l_i0 = l_i0 * alpha + l_ij
861
-
734
+ l_ij = float2 .pack2 (* _split_n (p )).sum (axis = 1 )
735
+ alpha = gl .convert_layout (alpha , l_i .value .type .layout , assert_trivial = True )
736
+ l_i = float2 .fma (l_i , float2 .pack2 (alpha , alpha ), l_ij )
862
737
m_i = m_ij
863
738
864
- return m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile
739
+ return m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile
865
740
866
741
867
742
@gluon .jit
@@ -876,11 +751,7 @@ def _softmax_tile(
876
751
exp_turnstile ,
877
752
):
878
753
qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
879
- sum_layout : gl .constexpr = (
880
- _get_split_n_layout (config .qk_layout )
881
- if config .use_fadd2_reduce
882
- else config .qk_layout
883
- )
754
+ sum_layout : gl .constexpr = _get_split_n_layout (config .qk_layout )
884
755
885
756
s_consumer = s_chnl .create_consumer ()
886
757
corr_producer = corr_chnl .create_producer ()
@@ -894,17 +765,12 @@ def _softmax_tile(
894
765
offs_m += gl .arange (tile_id * config .SPLIT_M , (1 + tile_id ) * config .SPLIT_M )
895
766
896
767
m_i = gl .full ([config .SPLIT_M ], - float ("inf" ), gl .float32 , qk_slice_dim1 )
897
- l_i0 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout ))
898
768
# Accumulate into 2 row-sums so the reduction can be performed with FADD2.
899
- if config .use_fadd2_reduce :
900
- l_i1 = gl .full (
901
- [config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout )
902
- )
903
- else :
904
- l_i1 = 0
769
+ l_i = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl .SliceLayout (1 , sum_layout ))
770
+ l_i = float2 .pack2 (l_i , l_i )
905
771
906
772
if STAGE & 1 :
907
- m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = (
773
+ m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile = (
908
774
_softmax_inner_loop ( #
909
775
tile_id ,
910
776
config ,
@@ -915,13 +781,12 @@ def _softmax_tile(
915
781
corr_bar , #
916
782
offs_m ,
917
783
m_i ,
918
- l_i0 ,
919
- l_i1 ,
784
+ l_i ,
920
785
STAGE = 4 - STAGE ,
921
786
)
922
787
)
923
788
if STAGE & 2 :
924
- m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = (
789
+ m_i , l_i , corr_bar , s_consumer , corr_producer , exp_turnstile = (
925
790
_softmax_inner_loop ( #
926
791
tile_id ,
927
792
config ,
@@ -932,16 +797,12 @@ def _softmax_tile(
932
797
corr_bar , #
933
798
offs_m ,
934
799
m_i ,
935
- l_i0 ,
936
- l_i1 ,
800
+ l_i ,
937
801
STAGE = 2 ,
938
802
)
939
803
)
940
-
941
- if config .use_fadd2_reduce :
942
- l_i = l_i0 + l_i1
943
- else :
944
- l_i = l_i0
804
+ l_i0 , l_i1 = float2 .unpack2 (l_i )
805
+ l_i = l_i0 + l_i1
945
806
946
807
s_tmem , s_bar , s_consumer = s_consumer .acquire ()
947
808
m_i_tmem , l_i_tmem = _borrow_s_for_epilogue (config , s_tmem )
@@ -1039,11 +900,14 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
1039
900
mbarrier .arrive (corr_bar , count = 1 )
1040
901
alpha = gl .convert_layout (alpha .reshape ([config .SPLIT_M ]), alpha_layout )
1041
902
903
+ alpha = float2 .pack (
904
+ alpha [:, None ].broadcast_to (config .o_shape [0 ], config .SPLIT_D ), axis = 1
905
+ )
1042
906
for i in gl .static_range (config .SPLIT_D_FACTOR ):
1043
907
o_ref = o_tmem .slice (i * config .SPLIT_D , config .SPLIT_D )
1044
- o = o_ref .load (config .o_splitn_layout )
1045
- o = _mul_f32x2 ( o , alpha [:, None ])
1046
- o_ref .store (o )
908
+ o = float2 . pack ( o_ref .load (config .o_splitn_layout ), axis = 1 )
909
+ o = o * alpha
910
+ o_ref .store (float2 . unpack ( o , axis = 1 ) )
1047
911
mbarrier .arrive (o_bar , count = 1 )
1048
912
return corr_consumer , o_consumer
1049
913
@@ -1081,12 +945,16 @@ def _attn_fwd_correction_epilogue(
1081
945
)
1082
946
SPLIT_N : gl .constexpr = o_smem .type .shape [1 ] // SPLIT_N_FACTOR
1083
947
1084
- scale = 1 / l_i
948
+ scale = float2 .pack (
949
+ (1 / l_i )[:, None ].broadcast_to (config .o_shape [0 ], SPLIT_N ), axis = 1
950
+ )
1085
951
for i in gl .static_range (SPLIT_N_FACTOR ):
1086
952
o_ref = o_tmem .slice (i * SPLIT_N , SPLIT_N )
1087
- o = o_ref .load (config .o_splitn_layout )
1088
- o = _mul_f32x2 (o , scale [:, None ])
1089
- o_smem .slice (i * SPLIT_N , SPLIT_N , dim = 1 ).store (o .to (config .dtype ))
953
+ o = float2 .pack (o_ref .load (config .o_splitn_layout ), axis = 1 )
954
+ o = o * scale
955
+ o_smem .slice (i * SPLIT_N , SPLIT_N , dim = 1 ).store (
956
+ float2 .unpack (o , axis = 1 ).to (config .dtype )
957
+ )
1090
958
1091
959
fence_async_shared ()
1092
960
mbarrier .arrive (epi_bar , count = 1 )
0 commit comments