
Commit 0bd996b

Enable bf16 for Gluon/FA (#7445)
1 parent 9fcb4b9 commit 0bd996b


python/tutorials/gluon/01-attention-forward.py

Lines changed: 7 additions & 8 deletions
@@ -334,6 +334,8 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE

         if dtype == gl.float16:
             self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
+        elif dtype == gl.bfloat16:
+            self.num_kv_buffers = gl.constexpr(3 if HEAD_DIM == 128 else 6)
         else:
             self.num_kv_buffers = gl.constexpr(4 if HEAD_DIM == 128 else 8)

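Note: the new elif gives bfloat16 the same K/V buffer budget as float16, which is consistent with both being 2-byte formats that occupy the same shared memory per element, while other widths (e.g. the 1-byte fp8 formats) fall through to the larger budget. A minimal sketch of that sizing rule keyed on element width rather than dtype identity (illustration only, not part of the commit; the function name and plain-int return are assumptions, whereas the tutorial wraps the value in gl.constexpr):

# Illustration only: express the K/V buffer budget in terms of element size.
# 2-byte dtypes (float16, bfloat16) share the smaller budget; other widths
# (e.g. 1-byte fp8 formats) get the larger one, matching the diff above.
def num_kv_buffers_for(element_bytes: int, head_dim: int) -> int:
    if element_bytes == 2:                    # float16 / bfloat16
        return 3 if head_dim == 128 else 6
    return 4 if head_dim == 128 else 8        # e.g. fp8 variants
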
@@ -934,7 +936,7 @@ def is_blackwell():
 @pytest.mark.parametrize("N_CTX", [256, 1024, 4 * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
     device = "cuda"
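
With torch.bfloat16 added to the dtype parametrization, the same test body now covers both 16-bit formats. One hedged way to exercise a single new case directly from a Python session (illustration only: the Z and H values below are arbitrary picks for this sketch, not the tutorial's parametrize values, and the skipif marker means a Blackwell GPU is still required):

# Illustration only: invoke the test body for one bf16 configuration.
# Z and H are arbitrary small values chosen for this sketch.
import torch
test_op(Z=1, H=2, N_CTX=256, HEAD_DIM=64, causal=True, dtype=torch.bfloat16)
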
@@ -945,12 +947,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5

-    M = torch.tril(torch.ones((N_CTX, N_CTX), device=device))
-    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-    if causal:
-        p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
-    ref_out = torch.matmul(p, v)
+    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)

     tri_out, _ = attention_forward(q, k, v, causal, sm_scale)
     torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
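
The removed hand-rolled reference cast the softmax probabilities with .half(), which is specific to float16; delegating to torch.nn.functional.scaled_dot_product_attention handles the causal mask internally and works uniformly for both 16-bit dtypes. A minimal sketch of the equivalence the new reference relies on (illustration only, not part of the commit; assumes PyTorch >= 2.1 for the scale= keyword and a CUDA device, and casts back to the input dtype instead of .half()):

# Illustration only: check that the old hand-rolled reference and
# F.scaled_dot_product_attention agree on small bf16 inputs.
import torch

Z, H, N_CTX, HEAD_DIM, sm_scale, causal = 1, 2, 64, 32, 0.5, True
q, k, v = (torch.empty(Z, H, N_CTX, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
           .normal_(mean=0.0, std=0.5) for _ in range(3))

# Old-style reference: explicit scores, causal mask, softmax in fp32,
# then cast back to the input dtype (the removed code used .half()).
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).to(q.dtype)
manual_ref = torch.matmul(p, v)

# New-style reference used by the test after this change.
sdpa_ref = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, scale=sm_scale, is_causal=causal)

torch.testing.assert_close(manual_ref, sdpa_ref, atol=1e-2, rtol=0)
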
@@ -964,7 +961,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
 N_HEADS = [32]
 HEAD_DIM = [64, 128]
 causal = [False, True]
-providers = ["triton-fp16", "triton-fp8", "cudnn-fp16"]
+providers = ["triton-fp16", "triton-bf16", "triton-fp8", "cudnn-fp16", "cudnn-bf16"]
 N_CTX = [2**i for i in range(10, 17)]

 bench_configs = []
@@ -993,6 +990,8 @@ def bench(Z, H, N_CTX, HEAD_DIM, causal, provider):
     provider, dtype = provider.split("-")
     if dtype == "fp16":
         dtype = torch.float16
+    elif dtype == "bf16":
+        dtype = torch.bfloat16
     elif dtype == "fp8":
         dtype = torch.float8_e5m2
     else:
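
The benchmark derives both the provider name and the torch dtype from the provider string's suffix, which is why the new "triton-bf16" and "cudnn-bf16" entries only need the extra elif above. An equivalent table-driven version of that mapping (illustration only, not what the commit does; it omits whatever the tutorial's else branch handles, and torch.float8_e5m2 requires a recent PyTorch):

# Illustration only: table-driven version of the suffix -> dtype mapping above.
import torch

_SUFFIX_TO_DTYPE = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp8": torch.float8_e5m2,
}

def parse_provider(provider: str):
    """Split e.g. "triton-bf16" into ("triton", torch.bfloat16)."""
    name, suffix = provider.split("-")
    return name, _SUFFIX_TO_DTYPE[suffix]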
