Skip to content

Commit de9309d

Browse files
authored
[Gluon][Tutorial] More attention optimizations (#7488)
* Implement a turnstile over `gl.exp2` to control access to MFU for D64 to reduce contention between both softmax partitions * For D64, use FADD2 to compute the row_sum by keeping 2 separate sums * This triggers spilling in the inner loop for no apparent reason, but subtiling by 4 appears to generate a spilling pattern that tiles the L1 cache, reducing the performance impact * Switch to FFMA2 for `qk` scale and row_max subtract Altogether, D64 is 100-120 TFLOPS faster across the board. fp16 D128 is about 50 TFLOPS faster, and causal D128 fp8 is up to 200 TFLOPS faster From here, there is line-of-sight to cuDNN-level performance at least for non-causal D64 fp16 and both causal and non-causal D128 fp8. Causal masking for fp16 appears to have a large cost ``` Attention Z=4 H=32 D=64 causal=False: N_CTX triton-fp16 triton-fp8 0 1024.0 338.201812 360.876850 1 2048.0 678.613324 695.699051 2 4096.0 745.084460 744.836786 3 8192.0 798.156543 789.420950 4 16384.0 806.820468 795.589407 5 32768.0 811.021537 798.095623 6 65536.0 807.884668 799.765676 Attention Z=4 H=32 D=64 causal=True: N_CTX triton-fp16 triton-fp8 0 1024.0 180.663904 183.450252 1 2048.0 448.904426 445.195799 2 4096.0 572.631109 568.819476 3 8192.0 680.352270 677.308814 4 16384.0 721.535576 718.793844 5 32768.0 748.864730 753.577144 6 65536.0 756.589109 765.344021 Attention Z=4 H=32 D=128 causal=False: N_CTX triton-fp16 triton-fp8 0 1024.0 708.799066 740.206893 1 2048.0 1149.176236 1199.307587 2 4096.0 1211.860275 1379.656870 3 8192.0 1261.080996 1417.870902 4 16384.0 1231.603976 1562.275662 5 32768.0 1210.829899 1580.506835 6 65536.0 1213.299696 1566.246921 Attention Z=4 H=32 D=128 causal=True: N_CTX triton-fp16 triton-fp8 0 1024.0 361.914022 373.335973 1 2048.0 775.257119 838.468061 2 4096.0 989.752169 1130.565960 3 8192.0 1132.501215 1356.102973 4 16384.0 1147.953291 1490.517324 5 32768.0 1166.571190 1597.625819 6 65536.0 1130.217345 1621.162896 ```
1 parent 732c0db commit de9309d

File tree

1 file changed

+169
-39
lines changed

1 file changed

+169
-39
lines changed

python/tutorials/gluon/01-attention-forward.py

Lines changed: 169 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,12 @@ class AttentionConfig:
293293
alpha_2d_layout: gl.constexpr
294294

295295
num_kv_buffers: gl.constexpr
296+
use_fadd2_reduce: gl.constexpr
297+
use_exp2_turnstile: gl.constexpr
298+
use_ffma2_scale_rowmax: gl.constexpr
296299

297-
def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, dtype, num_warps,
298-
SPLIT_D_FACTOR):
300+
def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype,
301+
num_warps, SPLIT_D_FACTOR):
299302
self.qk_scale = qk_scale
300303
self.Z = Z
301304
self.H = H
@@ -332,13 +335,16 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
332335
(self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
333336
self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))
334337

335-
if dtype == gl.float16:
336-
self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
337-
elif dtype == gl.bfloat16:
338+
is_fp16 = dtype.value in [gl.float16, gl.bfloat16]
339+
if is_fp16:
338340
self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
339341
else:
340342
self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)
341343

344+
self.use_fadd2_reduce = gl.constexpr(HEAD_DIM == 64)
345+
self.use_exp2_turnstile = gl.constexpr(HEAD_DIM == 64)
346+
self.use_ffma2_scale_rowmax = gl.constexpr(HEAD_DIM == 128 or is_fp16 == (STAGE == 3))
347+
342348
@gluon.jit
343349
def get_program(self, pid_m, pid_n):
344350
start_m = pid_m
@@ -470,6 +476,68 @@ def _mul_f32x2(a, b):
470476
)
471477

472478

479+
@gluon.jit
480+
def _fma_f32x2(a, b, c):
481+
return gl.inline_asm_elementwise(
482+
"""
483+
{
484+
.reg .b64 ra, rb, rc, rd;
485+
mov.b64 ra, { $2, $3 };
486+
mov.b64 rb, { $4, $5 };
487+
mov.b64 rc, { $6, $7 };
488+
fma.rn.f32x2 rd, ra, rb, rc;
489+
mov.b64 { $0, $1 }, rd;
490+
}
491+
""",
492+
"=r,=r,r,r,r,r,r,r",
493+
[a, b, c],
494+
dtype=gl.float32,
495+
is_pure=True,
496+
pack=2,
497+
)
498+
499+
500+
@gluon.jit
501+
def _reduce_fadd2(p0a, p1a, p0b, p1b):
502+
return gl.inline_asm_elementwise(
503+
"""
504+
{
505+
.reg .b64 rc, ra, rb;
506+
mov.b64 ra, { $2, $4 };
507+
mov.b64 rb, { $3, $5 };
508+
add.f32x2 rc, ra, rb;
509+
mov.b64 { $0, $1 }, rc;
510+
}
511+
""",
512+
"=r,=r,r,r,r,r",
513+
[p0a, p0b, p1a, p1b],
514+
dtype=[gl.float32, gl.float32],
515+
is_pure=True,
516+
pack=1,
517+
)
518+
519+
520+
@gluon.jit
521+
def _pairwise_fma_f32x2(a0, b0, c0, a1, b1, c1):
522+
return gl.inline_asm_elementwise(
523+
"""
524+
{
525+
.reg .b64 rd, ra, rb, rc;
526+
mov.b64 ra, { $2, $5 };
527+
mov.b64 rb, { $3, $6 };
528+
mov.b64 rc, { $4, $7 };
529+
fma.rn.f32x2 rd, ra, rb, rc;
530+
mov.b64 { $0, $1 }, rd;
531+
}
532+
""",
533+
"=r,=r,r,r,r,r,r,r",
534+
[a0, b0, c0, a1, b1, c1],
535+
dtype=[gl.float32, gl.float32],
536+
is_pure=True,
537+
pack=1,
538+
)
539+
540+
473541
# ===-----------------------------------------------------------------------===#
474542
# _gluon_attn
475543
# ===-----------------------------------------------------------------------===#
@@ -500,7 +568,7 @@ def _borrow_s_for_epilogue(config, s_tmem):
500568

501569
@gluon.jit
502570
def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
503-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
571+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
504572
desc_q, desc_k, desc_v, desc_o = descs
505573

506574
q_producer = q_chnl.create_producer()
@@ -536,7 +604,7 @@ def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
536604

537605
@gluon.jit
538606
def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
539-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
607+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
540608
desc_q, desc_k, desc_v, desc_o = descs
541609

542610
q_consumer = q_chnl.create_consumer()
@@ -598,8 +666,8 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
598666

599667
@gluon.jit
600668
def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
601-
s_consumer, corr_producer, corr_bar, #
602-
offs_m, offs_n, m_i, l_i, STAGE: gl.constexpr):
669+
s_consumer, corr_producer, exp_turnstile, corr_bar, #
670+
offs_m, offs_n, m_i, l_i0, l_i1, STAGE: gl.constexpr):
603671
lo, hi = prog.get_loop_bounds(STAGE)
604672

605673
for start_n in range(lo, hi, config.BLOCK_N):
@@ -619,31 +687,79 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
619687
alpha_tmem.store(gl.convert_layout(alpha.expand_dims(1), config.alpha_2d_layout))
620688
mbarrier.arrive(corr_bar, count=1)
621689

622-
qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
623-
qk = _add_f32x2(qk, -m_ij[:, None])
690+
if config.use_ffma2_scale_rowmax:
691+
qk = _fma_f32x2(qk, gl.full_like(qk, config.qk_scale), -m_ij[:, None])
692+
else:
693+
qk = _mul_f32x2(qk, gl.full_like(qk, config.qk_scale))
694+
qk = _add_f32x2(qk, -m_ij[:, None])
624695
qk0, qk1, = qk.reshape([config.SPLIT_M, 2, config.BLOCK_N // 2]).permute(0, 2, 1).split()
625696

626697
p_tmem = _borrow_s_as_p(config, s_tmem)
627-
p0 = gl.exp2(qk0)
628-
p_tmem.slice(0, config.BLOCK_N // 2).store(p0.to(config.dtype))
629-
p1 = gl.exp2(qk1)
630-
p_tmem.slice(config.BLOCK_N // 2, config.BLOCK_N // 2).store(p1.to(config.dtype))
698+
BN4: gl.constexpr = config.BLOCK_N // 4
699+
BN2: gl.constexpr = config.BLOCK_N // 2
700+
701+
# Force the softmax partitions to take turns in the EX2 section. This
702+
# prevents contention for the EX2 unit and improves utilization.
703+
if config.use_exp2_turnstile:
704+
_, exp_bar, exp_turnstile = exp_turnstile.acquire()
705+
706+
# FIXME: When using FADD2 reductions, ptxas misbehaves and spills far
707+
# below the register limit in the FADD2, FMUL2, EX2 section. Subtile by
708+
# 4 to minimize the spilling.
709+
if config.HEAD_DIM == 64:
710+
qk00, qk01 = qk0.reshape([config.SPLIT_M, 2, config.BLOCK_N // 4]).permute(0, 2, 1).split()
711+
p00 = gl.exp2(qk00)
712+
p_tmem.slice(0, BN4).store(p00.to(config.dtype))
713+
p01 = gl.exp2(qk01)
714+
p_tmem.slice(BN4, BN4).store(p01.to(config.dtype))
715+
p0 = gl.join(p00, p01).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N // 2])
716+
p0 = gl.convert_layout(p0, config.qk_layout)
717+
else:
718+
p0 = gl.exp2(qk0)
719+
p_tmem.slice(0, BN2).store(p0.to(config.dtype))
720+
721+
if config.HEAD_DIM == 64:
722+
qk10, qk11 = qk1.reshape([config.SPLIT_M, 2, config.BLOCK_N // 4]).permute(0, 2, 1).split()
723+
p10 = gl.exp2(qk10)
724+
p_tmem.slice(2 * BN4, BN4).store(p10.to(config.dtype))
725+
p11 = gl.exp2(qk11)
726+
p_tmem.slice(3 * BN4, BN4).store(p11.to(config.dtype))
727+
p1 = gl.join(p10, p11).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N // 2])
728+
p1 = gl.convert_layout(p1, config.qk_layout)
729+
else:
730+
p1 = gl.exp2(qk1)
731+
p_tmem.slice(BN2, BN2).store(p1.to(config.dtype))
732+
631733
mbarrier.arrive(s_bar, count=1)
632734

633735
_, corr_bar, corr_producer = corr_producer.acquire()
634736

635-
p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
636-
p = gl.convert_layout(p, config.qk_layout)
637-
l_ij = gl.sum(p, axis=1)
638-
l_i = l_i * alpha + l_ij
737+
if config.HEAD_DIM == 64:
738+
mbarrier.arrive(exp_bar, count=1)
739+
740+
if config.use_fadd2_reduce:
741+
l_ij0, l_ij1 = gl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
742+
# This is a difference of 1 SASS instruction but it dramatically
743+
# affects instruction scheduling.
744+
if config.dtype == gl.float8e5:
745+
l_i0, l_i1 = _pairwise_fma_f32x2(l_i0, alpha, l_ij0, l_i1, alpha, l_ij1)
746+
else:
747+
l_i0 = l_i0 * alpha + l_ij0
748+
l_i1 = l_i1 * alpha + l_ij1
749+
else:
750+
p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
751+
p = gl.convert_layout(p, config.qk_layout)
752+
l_ij = gl.sum(p, axis=1)
753+
l_i0 = l_i0 * alpha + l_ij
754+
639755
m_i = m_ij
640756

641-
return m_i, l_i, corr_bar, s_consumer, corr_producer
757+
return m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile
642758

643759

644760
@gluon.jit
645761
def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr, #
646-
s_chnl, corr_chnl):
762+
s_chnl, corr_chnl, exp_turnstile):
647763
qk_slice_dim0: gl.constexpr = gl.SliceLayout(0, config.qk_layout)
648764
qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
649765

@@ -661,16 +777,26 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
661777
offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M, qk_slice_dim1)
662778

663779
m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
664-
l_i = gl.full([config.SPLIT_M], 1.0, gl.float32, qk_slice_dim1)
780+
l_i0 = gl.full([config.SPLIT_M], 0.0, gl.float32, qk_slice_dim1)
781+
# Accumulate into 2 row-sums so the reduction can be performed with FADD2.
782+
if config.use_fadd2_reduce:
783+
l_i1 = gl.full([config.SPLIT_M], 0.0, gl.float32, qk_slice_dim1)
784+
else:
785+
l_i1 = 0
665786

666787
if STAGE & 1:
667-
m_i, l_i, corr_bar, s_consumer, corr_producer = _softmax_inner_loop( #
668-
tile_id, config, prog, s_consumer, corr_producer, corr_bar, #
669-
offs_m, offs_n, m_i, l_i, STAGE=4 - STAGE)
788+
m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
789+
tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
790+
offs_m, offs_n, m_i, l_i0, l_i1, STAGE=4 - STAGE)
670791
if STAGE & 2:
671-
m_i, l_i, corr_bar, s_consumer, corr_producer = _softmax_inner_loop( #
672-
tile_id, config, prog, s_consumer, corr_producer, corr_bar, #
673-
offs_m, offs_n, m_i, l_i, STAGE=2)
792+
m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
793+
tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
794+
offs_m, offs_n, m_i, l_i0, l_i1, STAGE=2)
795+
796+
if config.use_fadd2_reduce:
797+
l_i = l_i0 + l_i1
798+
else:
799+
l_i = l_i0
674800

675801
s_tmem, s_bar, s_consumer = s_consumer.acquire()
676802
m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
@@ -685,21 +811,21 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
685811

686812
@gluon.jit
687813
def _attn_fwd_softmax0(config, chnls, descs, M, STAGE: gl.constexpr):
688-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
814+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
689815
desc_q, desc_k, desc_v, desc_o = descs
690-
_softmax_tile(0, config, M, desc_o, STAGE, s0_chnl, c0_chnl)
816+
_softmax_tile(0, config, M, desc_o, STAGE, s0_chnl, c0_chnl, exp_turnstile.create_producer())
691817

692818

693819
@gluon.jit
694820
def _attn_fwd_softmax1(config, chnls, descs, M, STAGE: gl.constexpr):
695-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
821+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
696822
desc_q, desc_k, desc_v, desc_o = descs
697-
_softmax_tile(1, config, M, desc_o, STAGE, s1_chnl, c1_chnl)
823+
_softmax_tile(1, config, M, desc_o, STAGE, s1_chnl, c1_chnl, exp_turnstile.create_consumer())
698824

699825

700826
@gluon.jit
701827
def _attn_fwd_epilogue(config, chnls, descs, M, STAGE: gl.constexpr):
702-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
828+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
703829
desc_q, desc_k, desc_v, desc_o = descs
704830

705831
epi_consumer = epi_chnl.create_consumer()
@@ -723,12 +849,13 @@ def _attn_fwd_epilogue(config, chnls, descs, M, STAGE: gl.constexpr):
723849
def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
724850
alpha_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)
725851

852+
o_tmem, o_bar, o_consumer = o_consumer.acquire()
853+
726854
_, corr_bar, corr_consumer = corr_consumer.acquire()
727855
alpha = _borrow_s_as_alpha(config, s_tmem).load(config.alpha_2d_layout)
728856
mbarrier.arrive(corr_bar, count=1)
729857
alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)
730858

731-
o_tmem, o_bar, o_consumer = o_consumer.acquire()
732859
for i in tl.static_range(config.SPLIT_D_FACTOR):
733860
o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
734861
o = o_ref.load(config.o_splitn_layout)
@@ -753,6 +880,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
753880
o_smem, epi_bar, epi_producer = epi_producer.acquire()
754881
o_tmem, o_bar, o_consumer = o_consumer.acquire()
755882

883+
# Shared memory subtile size is limited by the swizzle byte size.
756884
contigDimSize: gl.constexpr = o_smem.type.layout.swizzle_byte_width * 8 / o_smem.type.element_ty.primitive_bitwidth
757885
if o_smem.type.shape[1] // config.SPLIT_D_FACTOR >= contigDimSize:
758886
SPLIT_N_FACTOR: gl.constexpr = config.SPLIT_D_FACTOR
@@ -785,7 +913,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
785913

786914
@gluon.jit
787915
def _attn_fwd_correction(config, chnls, descs, M, STAGE: gl.constexpr):
788-
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl = chnls
916+
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
789917

790918
s0_tmem = s0_chnl.mem.index(0)
791919
s1_tmem = s1_chnl.mem.index(0)
@@ -831,7 +959,7 @@ def attention_kernel( #
831959
GROUP_SIZE_N: gl.constexpr, NUM_SMS: gl.constexpr, STAGE: gl.constexpr, dtype: gl.constexpr, #
832960
num_warps: gl.constexpr):
833961
qk_scale = sm_scale * 1.44269504
834-
config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, # i
962+
config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, #
835963
dtype, num_warps, SPLIT_D_FACTOR=2)
836964

837965
q_chnl = get_desc_channel(desc_q, num_buffers=2)
@@ -842,8 +970,9 @@ def attention_kernel( #
842970
s1_chnl = TensorMemoryChannel.alloc(config.qk_shape, gl.float32, config.qk_tmem_layout, num_buffers=1)
843971
c0_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
844972
c1_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
973+
exp_turnstile = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
845974

846-
chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl)
975+
chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
847976
descs = (desc_q, desc_k, desc_v, desc_o)
848977
gl.warp_specialize((config, chnls, descs, M, STAGE), _attn_fwd_correction, (config, chnls, descs, M, STAGE), [
849978
_attn_fwd_softmax0,
@@ -861,6 +990,7 @@ def attention_kernel( #
861990
s1_chnl.release()
862991
c0_chnl.release()
863992
c1_chnl.release()
993+
exp_turnstile.release()
864994

865995

866996
# ===-----------------------------------------------------------------------===#
@@ -938,7 +1068,7 @@ def is_blackwell():
9381068
@pytest.mark.parametrize("causal", [False, True])
9391069
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
9401070
@pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
941-
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
1071+
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, profile=False):
9421072
device = "cuda"
9431073

9441074
torch.manual_seed(42)
@@ -961,7 +1091,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
9611091
N_HEADS = [32]
9621092
HEAD_DIM = [64, 128]
9631093
causal = [False, True]
964-
providers = ["triton-fp16", "triton-bf16", "triton-fp8", "cudnn-fp16", "cudnn-bf16"]
1094+
providers = ["triton-fp16", "triton-fp8"]
9651095
N_CTX = [2**i for i in range(10, 17)]
9661096

9671097
bench_configs = []

0 commit comments

Comments (0)