@@ -293,9 +293,12 @@ class AttentionConfig:
293
293
alpha_2d_layout : gl .constexpr
294
294
295
295
num_kv_buffers : gl .constexpr
296
+ use_fadd2_reduce : gl .constexpr
297
+ use_exp2_turnstile : gl .constexpr
298
+ use_ffma2_scale_rowmax : gl .constexpr
296
299
297
- def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , dtype , num_warps ,
298
- SPLIT_D_FACTOR ):
300
+ def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , STAGE , dtype ,
301
+ num_warps , SPLIT_D_FACTOR ):
299
302
self .qk_scale = qk_scale
300
303
self .Z = Z
301
304
self .H = H
@@ -332,13 +335,16 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
332
335
(self .o_shape [0 ], self .o_shape [1 ] // self .SPLIT_D_FACTOR ), self .num_warps ))
333
336
self .alpha_2d_layout = gl .constexpr (gl .BlockedLayout ([1 , 1 ], [32 , 1 ], [self .num_warps , 1 ], [0 , 1 ]))
334
337
335
- if dtype == gl .float16 :
336
- self .num_kv_buffers = gl .constexpr (3 if HEAD_DIM == 128 else 6 )
337
- elif dtype == gl .bfloat16 :
338
+ is_fp16 = dtype .value in [gl .float16 , gl .bfloat16 ]
339
+ if is_fp16 :
338
340
self .num_kv_buffers = gl .constexpr (3 if HEAD_DIM == 128 else 6 )
339
341
else :
340
342
self .num_kv_buffers = gl .constexpr (4 if HEAD_DIM == 128 else 8 )
341
343
344
+ self .use_fadd2_reduce = gl .constexpr (HEAD_DIM == 64 )
345
+ self .use_exp2_turnstile = gl .constexpr (HEAD_DIM == 64 )
346
+ self .use_ffma2_scale_rowmax = gl .constexpr (HEAD_DIM == 128 or is_fp16 == (STAGE == 3 ))
347
+
342
348
@gluon .jit
343
349
def get_program (self , pid_m , pid_n ):
344
350
start_m = pid_m
@@ -470,6 +476,68 @@ def _mul_f32x2(a, b):
470
476
)
471
477
472
478
479
+ @gluon .jit
480
+ def _fma_f32x2 (a , b , c ):
481
+ return gl .inline_asm_elementwise (
482
+ """
483
+ {
484
+ .reg .b64 ra, rb, rc, rd;
485
+ mov.b64 ra, { $2, $3 };
486
+ mov.b64 rb, { $4, $5 };
487
+ mov.b64 rc, { $6, $7 };
488
+ fma.rn.f32x2 rd, ra, rb, rc;
489
+ mov.b64 { $0, $1 }, rd;
490
+ }
491
+ """ ,
492
+ "=r,=r,r,r,r,r,r,r" ,
493
+ [a , b , c ],
494
+ dtype = gl .float32 ,
495
+ is_pure = True ,
496
+ pack = 2 ,
497
+ )
498
+
499
+
500
+ @gluon .jit
501
+ def _reduce_fadd2 (p0a , p1a , p0b , p1b ):
502
+ return gl .inline_asm_elementwise (
503
+ """
504
+ {
505
+ .reg .b64 rc, ra, rb;
506
+ mov.b64 ra, { $2, $4 };
507
+ mov.b64 rb, { $3, $5 };
508
+ add.f32x2 rc, ra, rb;
509
+ mov.b64 { $0, $1 }, rc;
510
+ }
511
+ """ ,
512
+ "=r,=r,r,r,r,r" ,
513
+ [p0a , p0b , p1a , p1b ],
514
+ dtype = [gl .float32 , gl .float32 ],
515
+ is_pure = True ,
516
+ pack = 1 ,
517
+ )
518
+
519
+
520
+ @gluon .jit
521
+ def _pairwise_fma_f32x2 (a0 , b0 , c0 , a1 , b1 , c1 ):
522
+ return gl .inline_asm_elementwise (
523
+ """
524
+ {
525
+ .reg .b64 rd, ra, rb, rc;
526
+ mov.b64 ra, { $2, $5 };
527
+ mov.b64 rb, { $3, $6 };
528
+ mov.b64 rc, { $4, $7 };
529
+ fma.rn.f32x2 rd, ra, rb, rc;
530
+ mov.b64 { $0, $1 }, rd;
531
+ }
532
+ """ ,
533
+ "=r,=r,r,r,r,r,r,r" ,
534
+ [a0 , b0 , c0 , a1 , b1 , c1 ],
535
+ dtype = [gl .float32 , gl .float32 ],
536
+ is_pure = True ,
537
+ pack = 1 ,
538
+ )
539
+
540
+
473
541
# ===-----------------------------------------------------------------------===#
474
542
# _gluon_attn
475
543
# ===-----------------------------------------------------------------------===#
@@ -500,7 +568,7 @@ def _borrow_s_for_epilogue(config, s_tmem):
500
568
501
569
@gluon .jit
502
570
def _attn_fwd_load (config , chnls , descs , M , STAGE : gl .constexpr ):
503
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
571
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
504
572
desc_q , desc_k , desc_v , desc_o = descs
505
573
506
574
q_producer = q_chnl .create_producer ()
@@ -536,7 +604,7 @@ def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
536
604
537
605
@gluon .jit
538
606
def _attn_fwd_mma (config , chnls , descs , M , STAGE : gl .constexpr ):
539
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
607
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
540
608
desc_q , desc_k , desc_v , desc_o = descs
541
609
542
610
q_consumer = q_chnl .create_consumer ()
@@ -598,8 +666,8 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
598
666
599
667
@gluon .jit
600
668
def _softmax_inner_loop (tile_id : gl .constexpr , config , prog , #
601
- s_consumer , corr_producer , corr_bar , #
602
- offs_m , offs_n , m_i , l_i , STAGE : gl .constexpr ):
669
+ s_consumer , corr_producer , exp_turnstile , corr_bar , #
670
+ offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE : gl .constexpr ):
603
671
lo , hi = prog .get_loop_bounds (STAGE )
604
672
605
673
for start_n in range (lo , hi , config .BLOCK_N ):
@@ -619,31 +687,79 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
619
687
alpha_tmem .store (gl .convert_layout (alpha .expand_dims (1 ), config .alpha_2d_layout ))
620
688
mbarrier .arrive (corr_bar , count = 1 )
621
689
622
- qk = _mul_f32x2 (qk , gl .full_like (qk , config .qk_scale ))
623
- qk = _add_f32x2 (qk , - m_ij [:, None ])
690
+ if config .use_ffma2_scale_rowmax :
691
+ qk = _fma_f32x2 (qk , gl .full_like (qk , config .qk_scale ), - m_ij [:, None ])
692
+ else :
693
+ qk = _mul_f32x2 (qk , gl .full_like (qk , config .qk_scale ))
694
+ qk = _add_f32x2 (qk , - m_ij [:, None ])
624
695
qk0 , qk1 , = qk .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 2 ]).permute (0 , 2 , 1 ).split ()
625
696
626
697
p_tmem = _borrow_s_as_p (config , s_tmem )
627
- p0 = gl .exp2 (qk0 )
628
- p_tmem .slice (0 , config .BLOCK_N // 2 ).store (p0 .to (config .dtype ))
629
- p1 = gl .exp2 (qk1 )
630
- p_tmem .slice (config .BLOCK_N // 2 , config .BLOCK_N // 2 ).store (p1 .to (config .dtype ))
698
+ BN4 : gl .constexpr = config .BLOCK_N // 4
699
+ BN2 : gl .constexpr = config .BLOCK_N // 2
700
+
701
+ # Force the softmax partitions to take turns in the EX2 section. This
702
+ # prevents contention for the EX2 unit and improves utilization.
703
+ if config .use_exp2_turnstile :
704
+ _ , exp_bar , exp_turnstile = exp_turnstile .acquire ()
705
+
706
+ # FIXME: When using FADD2 reductions, ptxas misbehaves and spills far
707
+ # below the register limit in the FADD2, FMUL2, EX2 section. Subtile by
708
+ # 4 to minimize the spilling.
709
+ if config .HEAD_DIM == 64 :
710
+ qk00 , qk01 = qk0 .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 4 ]).permute (0 , 2 , 1 ).split ()
711
+ p00 = gl .exp2 (qk00 )
712
+ p_tmem .slice (0 , BN4 ).store (p00 .to (config .dtype ))
713
+ p01 = gl .exp2 (qk01 )
714
+ p_tmem .slice (BN4 , BN4 ).store (p01 .to (config .dtype ))
715
+ p0 = gl .join (p00 , p01 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N // 2 ])
716
+ p0 = gl .convert_layout (p0 , config .qk_layout )
717
+ else :
718
+ p0 = gl .exp2 (qk0 )
719
+ p_tmem .slice (0 , BN2 ).store (p0 .to (config .dtype ))
720
+
721
+ if config .HEAD_DIM == 64 :
722
+ qk10 , qk11 = qk1 .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 4 ]).permute (0 , 2 , 1 ).split ()
723
+ p10 = gl .exp2 (qk10 )
724
+ p_tmem .slice (2 * BN4 , BN4 ).store (p10 .to (config .dtype ))
725
+ p11 = gl .exp2 (qk11 )
726
+ p_tmem .slice (3 * BN4 , BN4 ).store (p11 .to (config .dtype ))
727
+ p1 = gl .join (p10 , p11 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N // 2 ])
728
+ p1 = gl .convert_layout (p1 , config .qk_layout )
729
+ else :
730
+ p1 = gl .exp2 (qk1 )
731
+ p_tmem .slice (BN2 , BN2 ).store (p1 .to (config .dtype ))
732
+
631
733
mbarrier .arrive (s_bar , count = 1 )
632
734
633
735
_ , corr_bar , corr_producer = corr_producer .acquire ()
634
736
635
- p = gl .join (p0 , p1 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N ])
636
- p = gl .convert_layout (p , config .qk_layout )
637
- l_ij = gl .sum (p , axis = 1 )
638
- l_i = l_i * alpha + l_ij
737
+ if config .HEAD_DIM == 64 :
738
+ mbarrier .arrive (exp_bar , count = 1 )
739
+
740
+ if config .use_fadd2_reduce :
741
+ l_ij0 , l_ij1 = gl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
742
+ # This is a difference of 1 SASS instruction but it dramatically
743
+ # affects instruction scheduling.
744
+ if config .dtype == gl .float8e5 :
745
+ l_i0 , l_i1 = _pairwise_fma_f32x2 (l_i0 , alpha , l_ij0 , l_i1 , alpha , l_ij1 )
746
+ else :
747
+ l_i0 = l_i0 * alpha + l_ij0
748
+ l_i1 = l_i1 * alpha + l_ij1
749
+ else :
750
+ p = gl .join (p0 , p1 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N ])
751
+ p = gl .convert_layout (p , config .qk_layout )
752
+ l_ij = gl .sum (p , axis = 1 )
753
+ l_i0 = l_i0 * alpha + l_ij
754
+
639
755
m_i = m_ij
640
756
641
- return m_i , l_i , corr_bar , s_consumer , corr_producer
757
+ return m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile
642
758
643
759
644
760
@gluon .jit
645
761
def _softmax_tile (tile_id : gl .constexpr , config , M , desc_o , STAGE : gl .constexpr , #
646
- s_chnl , corr_chnl ):
762
+ s_chnl , corr_chnl , exp_turnstile ):
647
763
qk_slice_dim0 : gl .constexpr = gl .SliceLayout (0 , config .qk_layout )
648
764
qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
649
765
@@ -661,16 +777,26 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
661
777
offs_m += gl .arange (tile_id * config .SPLIT_M , (1 + tile_id ) * config .SPLIT_M , qk_slice_dim1 )
662
778
663
779
m_i = gl .full ([config .SPLIT_M ], - float ("inf" ), gl .float32 , qk_slice_dim1 )
664
- l_i = gl .full ([config .SPLIT_M ], 1.0 , gl .float32 , qk_slice_dim1 )
780
+ l_i0 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , qk_slice_dim1 )
781
+ # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
782
+ if config .use_fadd2_reduce :
783
+ l_i1 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , qk_slice_dim1 )
784
+ else :
785
+ l_i1 = 0
665
786
666
787
if STAGE & 1 :
667
- m_i , l_i , corr_bar , s_consumer , corr_producer = _softmax_inner_loop ( #
668
- tile_id , config , prog , s_consumer , corr_producer , corr_bar , #
669
- offs_m , offs_n , m_i , l_i , STAGE = 4 - STAGE )
788
+ m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
789
+ tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
790
+ offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 4 - STAGE )
670
791
if STAGE & 2 :
671
- m_i , l_i , corr_bar , s_consumer , corr_producer = _softmax_inner_loop ( #
672
- tile_id , config , prog , s_consumer , corr_producer , corr_bar , #
673
- offs_m , offs_n , m_i , l_i , STAGE = 2 )
792
+ m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
793
+ tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
794
+ offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 2 )
795
+
796
+ if config .use_fadd2_reduce :
797
+ l_i = l_i0 + l_i1
798
+ else :
799
+ l_i = l_i0
674
800
675
801
s_tmem , s_bar , s_consumer = s_consumer .acquire ()
676
802
m_i_tmem , l_i_tmem = _borrow_s_for_epilogue (config , s_tmem )
@@ -685,21 +811,21 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
685
811
686
812
@gluon .jit
687
813
def _attn_fwd_softmax0 (config , chnls , descs , M , STAGE : gl .constexpr ):
688
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
814
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
689
815
desc_q , desc_k , desc_v , desc_o = descs
690
- _softmax_tile (0 , config , M , desc_o , STAGE , s0_chnl , c0_chnl )
816
+ _softmax_tile (0 , config , M , desc_o , STAGE , s0_chnl , c0_chnl , exp_turnstile . create_producer () )
691
817
692
818
693
819
@gluon .jit
694
820
def _attn_fwd_softmax1 (config , chnls , descs , M , STAGE : gl .constexpr ):
695
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
821
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
696
822
desc_q , desc_k , desc_v , desc_o = descs
697
- _softmax_tile (1 , config , M , desc_o , STAGE , s1_chnl , c1_chnl )
823
+ _softmax_tile (1 , config , M , desc_o , STAGE , s1_chnl , c1_chnl , exp_turnstile . create_consumer () )
698
824
699
825
700
826
@gluon .jit
701
827
def _attn_fwd_epilogue (config , chnls , descs , M , STAGE : gl .constexpr ):
702
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
828
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
703
829
desc_q , desc_k , desc_v , desc_o = descs
704
830
705
831
epi_consumer = epi_chnl .create_consumer ()
@@ -723,12 +849,13 @@ def _attn_fwd_epilogue(config, chnls, descs, M, STAGE: gl.constexpr):
723
849
def _attn_fwd_correction_rescale (config , s_tmem , corr_consumer , o_consumer ):
724
850
alpha_layout : gl .constexpr = gl .SliceLayout (1 , config .o_splitn_layout )
725
851
852
+ o_tmem , o_bar , o_consumer = o_consumer .acquire ()
853
+
726
854
_ , corr_bar , corr_consumer = corr_consumer .acquire ()
727
855
alpha = _borrow_s_as_alpha (config , s_tmem ).load (config .alpha_2d_layout )
728
856
mbarrier .arrive (corr_bar , count = 1 )
729
857
alpha = gl .convert_layout (alpha .reshape ([config .SPLIT_M ]), alpha_layout )
730
858
731
- o_tmem , o_bar , o_consumer = o_consumer .acquire ()
732
859
for i in tl .static_range (config .SPLIT_D_FACTOR ):
733
860
o_ref = o_tmem .slice (i * config .SPLIT_D , config .SPLIT_D )
734
861
o = o_ref .load (config .o_splitn_layout )
@@ -753,6 +880,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
753
880
o_smem , epi_bar , epi_producer = epi_producer .acquire ()
754
881
o_tmem , o_bar , o_consumer = o_consumer .acquire ()
755
882
883
+ # Shared memory subtile size is limited by the swizzle byte size.
756
884
contigDimSize : gl .constexpr = o_smem .type .layout .swizzle_byte_width * 8 / o_smem .type .element_ty .primitive_bitwidth
757
885
if o_smem .type .shape [1 ] // config .SPLIT_D_FACTOR >= contigDimSize :
758
886
SPLIT_N_FACTOR : gl .constexpr = config .SPLIT_D_FACTOR
@@ -785,7 +913,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
785
913
786
914
@gluon .jit
787
915
def _attn_fwd_correction (config , chnls , descs , M , STAGE : gl .constexpr ):
788
- q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl = chnls
916
+ q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
789
917
790
918
s0_tmem = s0_chnl .mem .index (0 )
791
919
s1_tmem = s1_chnl .mem .index (0 )
@@ -831,7 +959,7 @@ def attention_kernel( #
831
959
GROUP_SIZE_N : gl .constexpr , NUM_SMS : gl .constexpr , STAGE : gl .constexpr , dtype : gl .constexpr , #
832
960
num_warps : gl .constexpr ):
833
961
qk_scale = sm_scale * 1.44269504
834
- config = AttentionConfig (qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , # i
962
+ config = AttentionConfig (qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , STAGE , #
835
963
dtype , num_warps , SPLIT_D_FACTOR = 2 )
836
964
837
965
q_chnl = get_desc_channel (desc_q , num_buffers = 2 )
@@ -842,8 +970,9 @@ def attention_kernel( #
842
970
s1_chnl = TensorMemoryChannel .alloc (config .qk_shape , gl .float32 , config .qk_tmem_layout , num_buffers = 1 )
843
971
c0_chnl = SharedMemoryChannel .alloc ([1 ], gl .int8 , gl .constexpr (mbarrier .MBarrierLayout ()), num_buffers = 1 )
844
972
c1_chnl = SharedMemoryChannel .alloc ([1 ], gl .int8 , gl .constexpr (mbarrier .MBarrierLayout ()), num_buffers = 1 )
973
+ exp_turnstile = SharedMemoryChannel .alloc ([1 ], gl .int8 , gl .constexpr (mbarrier .MBarrierLayout ()), num_buffers = 1 )
845
974
846
- chnls = (q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl )
975
+ chnls = (q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile )
847
976
descs = (desc_q , desc_k , desc_v , desc_o )
848
977
gl .warp_specialize ((config , chnls , descs , M , STAGE ), _attn_fwd_correction , (config , chnls , descs , M , STAGE ), [
849
978
_attn_fwd_softmax0 ,
@@ -861,6 +990,7 @@ def attention_kernel( #
861
990
s1_chnl .release ()
862
991
c0_chnl .release ()
863
992
c1_chnl .release ()
993
+ exp_turnstile .release ()
864
994
865
995
866
996
# ===-----------------------------------------------------------------------===#
@@ -938,7 +1068,7 @@ def is_blackwell():
938
1068
@pytest .mark .parametrize ("causal" , [False , True ])
939
1069
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
940
1070
@pytest .mark .skipif (not is_blackwell (), reason = "Gluon attention is only supported on Blackwell GPUs" )
941
- def test_op (Z , H , N_CTX , HEAD_DIM , causal , dtype ):
1071
+ def test_op (Z , H , N_CTX , HEAD_DIM , causal , dtype , profile = False ):
942
1072
device = "cuda"
943
1073
944
1074
torch .manual_seed (42 )
@@ -961,7 +1091,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
961
1091
N_HEADS = [32 ]
962
1092
HEAD_DIM = [64 , 128 ]
963
1093
causal = [False , True ]
964
- providers = ["triton-fp16" , "triton-bf16" , "triton- fp8" , "cudnn-fp16" , "cudnn-bf16 " ]
1094
+ providers = ["triton-fp16" , "triton-fp8" ]
965
1095
N_CTX = [2 ** i for i in range (10 , 17 )]
966
1096
967
1097
bench_configs = []
0 commit comments