
Commit 75d27b0

[Tutorial] Fix 06-fused-attention.py of FP8 provider (#7043)
When the provider is `fp8`, `v` is permuted as shown below, and its new stride is `(H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX)`:

```python
if mode == "fwd" and "fp8" in provider:
    v = v.permute(0, 1, 3, 2).contiguous()
    v = v.permute(0, 1, 3, 2)
```

This PR fixes the FP8 dtype handling in the fused-attention kernel by separating the `k` and `v` offset calculations and updating the related configuration details. Key changes include:

- Renaming and separating the offset variables for the `k` and `v` computations.
- Adjusting the offset calculation for the FP8 dtype and updating the tensor descriptor creation.
- Expanding the configuration options for `BLOCK_N` and refining the device-specific configuration conditions.

---------

Signed-off-by: Whitney Tsang <[email protected]>
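For reference, here is a minimal PyTorch sketch (with hypothetical sizes) of why that double permute leaves the shape of `v` unchanged but produces the `(H*N_CTX*HEAD_DIM, N_CTX*HEAD_DIM, 1, N_CTX)` stride quoted above:

```python
import torch

# Hypothetical sizes, chosen only to illustrate the stride change described
# in the commit message.
Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64
v = torch.randn(Z, H, N_CTX, HEAD_DIM)
assert v.stride() == (H * N_CTX * HEAD_DIM, N_CTX * HEAD_DIM, HEAD_DIM, 1)

v = v.permute(0, 1, 3, 2).contiguous()  # storage becomes (Z, H, HEAD_DIM, N_CTX), row-major
v = v.permute(0, 1, 3, 2)               # logical shape is (Z, H, N_CTX, HEAD_DIM) again

# Same logical shape, but the last two dimensions are now laid out column-major,
# which is exactly the stride the kernel has to account for in the fp8 path.
assert v.shape == (Z, H, N_CTX, HEAD_DIM)
assert v.stride() == (H * N_CTX * HEAD_DIM, N_CTX * HEAD_DIM, 1, N_CTX)
```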
1 parent 1f12637 · commit 75d27b0

File tree: 1 file changed (+69 −22 lines)


python/tutorials/06-fused-attention.py

Lines changed: 69 additions & 22 deletions
```diff
@@ -56,12 +56,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
     # causal = False
     else:
         lo, hi = 0, N_CTX
-    offsetkv_y = offset_y + lo
+    offsetk_y = offset_y + lo
+    if dtype == tl.float8e5:
+        offsetv_y = offset_y * HEAD_DIM + lo
+    else:
+        offsetv_y = offset_y + lo
     # loop over k, v and update accumulator
     for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = desc_k.load([offsetkv_y, 0]).T
+        k = desc_k.load([offsetk_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -86,15 +90,19 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         else:
             acc = acc * alpha[:, None]
         # prepare p and v for the dot
-        v = desc_v.load([offsetkv_y, 0])
+        if dtype == tl.float8e5:
+            v = desc_v.load([0, offsetv_y]).T
+        else:
+            v = desc_v.load([offsetv_y, 0])
         p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         # place this at the end of the loop to reduce register pressure
         l_i = l_i * alpha + l_ij
         m_i = m_ij
-        offsetkv_y += BLOCK_N
+        offsetk_y += BLOCK_N
+        offsetv_y += BLOCK_N
     return acc, l_i, m_i
 
 
@@ -105,7 +113,10 @@ def _host_descriptor_pre_hook(nargs):
     if not isinstance(nargs["desc_q"], TensorDescriptor):
         return
     nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
-    nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
+    if nargs["FP8_OUTPUT"]:
+        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
+    else:
+        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
     nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
 
@@ -120,7 +131,7 @@ def _host_descriptor_pre_hook(nargs):
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
     for BM in [64, 128]\
-    for BN in [64, 128]\
+    for BN in [32, 64, 128]\
     for s in NUM_STAGES_OPTIONS \
     for w in [4, 8]\
 ]
@@ -134,7 +145,8 @@ def _host_descriptor_pre_hook(nargs):
 def keep(conf):
     BLOCK_M = conf.kwargs["BLOCK_M"]
     BLOCK_N = conf.kwargs["BLOCK_N"]
-    return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
+    return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
+                and conf.num_warps == 8)
 
 
 def prune_invalid_configs(configs, named_args, **kwargs):
@@ -174,8 +186,12 @@ def _attn_fwd(sm_scale, M, #
     y_dim = Z * H * N_CTX
     desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_M, HEAD_DIM])
-    desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
-                                     block_shape=[BLOCK_N, HEAD_DIM])
+    if FP8_OUTPUT:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim], strides=[N_CTX, 1],
+                                         block_shape=[HEAD_DIM, BLOCK_N])
+    else:
+        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
+                                         block_shape=[BLOCK_N, HEAD_DIM])
     desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                      block_shape=[BLOCK_N, HEAD_DIM])
     desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
@@ -494,7 +510,12 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
 
             dummy_block = [1, 1]
             desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
-            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+            if q.dtype == torch.float8_e5m2:
+                desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1],
+                                          block_shape=dummy_block)
+            else:
+                desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1],
+                                          block_shape=dummy_block)
             desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
             desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
         else:
@@ -579,48 +600,74 @@ def backward(ctx, do):
 
 attention = _attention.apply
 
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+
 
 @pytest.mark.parametrize("Z", [1, 4])
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
 @pytest.mark.parametrize("causal", [True])  # FIXME: Non-causal tests do not pass at the moment.
 @pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
-def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
+@pytest.mark.parametrize("mode", ["fwd", "bwd"])
+@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
+    if mode == "fwd" and "fp16" in provider:
+        pytest.skip("Avoid running the forward computation twice.")
+    if mode == "bwd" and "fp8" in provider:
+        pytest.skip("Backward pass with FP8 is not supported.")
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
     sm_scale = 0.5
-    dout = torch.randn_like(q)
     # reference implementation
+    ref_dtype = dtype
+    if mode == "fwd" and "fp8" in provider:
+        ref_dtype = torch.float32
+    q = q.to(ref_dtype)
+    k = k.to(ref_dtype)
+    v = v.to(ref_dtype)
     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
-    p = torch.softmax(p.float(), dim=-1).half()
+    p = torch.softmax(p.float(), dim=-1)
+    p = p.to(ref_dtype)
     # p = torch.exp(p)
-    ref_out = torch.matmul(p, v)
-    ref_out.backward(dout)
-    ref_dv, v.grad = v.grad.clone(), None
-    ref_dk, k.grad = k.grad.clone(), None
-    ref_dq, q.grad = q.grad.clone(), None
+    ref_out = torch.matmul(p, v).half()
+    if mode == "bwd":
+        dout = torch.randn_like(q)
+        ref_out.backward(dout)
+        ref_dv, v.grad = v.grad.clone(), None
+        ref_dk, k.grad = k.grad.clone(), None
+        ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
+    if mode == "fwd" and "fp8" in provider:
+        q = q.to(torch.float8_e5m2)
+        k = k.to(torch.float8_e5m2)
+        v = v.permute(0, 1, 3, 2).contiguous()
+        v = v.permute(0, 1, 3, 2)
+        v = v.to(torch.float8_e5m2)
     tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
+    if mode == "fwd":
+        atol = 3 if "fp8" in provider else 1e-2
+        torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
+        return
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
     # compare
-    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
     rtol = 0.0
     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
     if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
         rtol = 1e-2
-    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
-    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
+    torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
 
 
 try:
```
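For orientation, a condensed sketch of how the new FP8 forward path is exercised, mirroring the test changes above. It assumes the definitions from `python/tutorials/06-fused-attention.py` (`attention`, `DEVICE`) are in scope, that the hardware supports the FP8 path, and the sizes are hypothetical:

```python
import torch

# Hypothetical sizes; q/k/v are prepared the same way as in the updated test above.
Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)
k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)
v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)

q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
# Give v the transposed memory layout (only the strides change, not the shape), then cast.
v = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e5m2)

# causal=True, sm_scale=0.5, warp_specialize=False, passed positionally as in the test.
tri_out = attention(q, k, v, True, 0.5, False).half()
```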
