@@ -334,6 +334,8 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
 
         if dtype == gl.float16:
             self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
+        elif dtype == gl.bfloat16:
+            self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
         else:
             self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)
 
@@ -934,7 +936,7 @@ def is_blackwell():
 @pytest.mark.parametrize("N_CTX", [256, 1024, 4 * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
     device = "cuda"
@@ -945,12 +947,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
 
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device=device))
-    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-    if causal:
-        p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
-    ref_out = torch.matmul(p, v)
+    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)
 
     tri_out, _ = attention_forward(q, k, v, causal, sm_scale)
     torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
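The removed reference built the causal mask and softmax by hand and was hard-coded to fp16 via `.half()`; delegating to `torch.nn.functional.scaled_dot_product_attention` keeps the reference dtype-agnostic now that bf16 is also parametrized, and its `scale=` and `is_causal=` arguments map directly onto the `sm_scale` and `causal` test parameters. As a sanity check, here is a standalone Python sketch (not part of the patch; the shapes, random inputs, CPU fallback, and tolerance are illustrative choices) showing the two formulations agree:

```python
import torch

# Illustrative problem size; the real test sweeps larger configurations.
Z, H, N_CTX, HEAD_DIM = 2, 4, 256, 64
sm_scale = 0.5
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5)

# Hand-rolled reference (what the patch removes): scaled scores, causal mask,
# softmax in fp32, then the PV matmul back in the input dtype.
M = torch.tril(torch.ones((N_CTX, N_CTX), device=device))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
p = p.masked_fill(M == 0, float("-inf"))
manual = torch.matmul(torch.softmax(p.float(), dim=-1).to(dtype), v)

# One-line reference (what the patch adds).
sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=True)

# Tolerance is an illustrative choice for bf16 rounding, not the test's setting.
torch.testing.assert_close(manual, sdpa, atol=2e-2, rtol=0)
print("manual and SDPA references agree")
```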
@@ -964,7 +961,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
 N_HEADS = [32]
 HEAD_DIM = [64, 128]
 causal = [False, True]
-providers = ["triton-fp16", "triton-fp8", "cudnn-fp16"]
+providers = ["triton-fp16", "triton-bf16", "triton-fp8", "cudnn-fp16", "cudnn-bf16"]
 N_CTX = [2 ** i for i in range(10, 17)]
 
 bench_configs = []
@@ -993,6 +990,8 @@ def bench(Z, H, N_CTX, HEAD_DIM, causal, provider):
     provider, dtype = provider.split("-")
     if dtype == "fp16":
         dtype = torch.float16
+    elif dtype == "bf16":
+        dtype = torch.bfloat16
     elif dtype == "fp8":
         dtype = torch.float8_e5m2
     else:
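Each benchmark provider string splits on "-" into a backend tag and a dtype tag, and the new bf16 entries flow through the same branch structure as fp16 and fp8. A minimal standalone sketch of that mapping (dict-based here for brevity; the helper name is made up and not part of the patch):

```python
import torch

# Provider dtype tags used by the benchmark, including the new bf16 entries.
_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e5m2}

def parse_provider(provider: str):
    # "triton-bf16" -> ("triton", torch.bfloat16)
    backend, dtype_tag = provider.split("-")
    return backend, _DTYPES[dtype_tag]

for p in ["triton-fp16", "triton-bf16", "triton-fp8", "cudnn-fp16", "cudnn-bf16"]:
    print(p, "->", parse_provider(p))
```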