Skip to content

Commit 2e94e10

Browse files
author
pytorchbot
committed
2026-03-23 nightly release (7abd4b0)
1 parent 3e4afc1 commit 2e94e10

File tree

2 files changed

+275
-0
lines changed

2 files changed

+275
-0
lines changed

test/attention/bench_ck_fa4.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# pyre-unsafe
7+
8+
"""Benchmark CK FMHA with FA4 conditional rescaling."""
9+
10+
import math
11+
12+
import torch
13+
from mslk.attention import fmha
14+
15+
16+
def bench(B, H, M, K, dtype=torch.bfloat16, n_warmup=10, n_iters=50):
    """Benchmark one CK FMHA forward shape and print timing plus accuracy.

    Args:
        B, H, M, K: batch, heads, sequence length, head dim (BMHK layout).
        dtype: dtype for the Q/K/V tensors.
        n_warmup: untimed warmup iterations before the measured loop.
        n_iters: timed iterations averaged into the reported latency.
    """
    qry, key, val = (
        torch.randn(B, M, H, K, device="cuda", dtype=dtype) for _ in range(3)
    )

    def run_fwd():
        # Single CK FMHA forward call; reused for check, warmup and timing.
        return fmha.memory_efficient_attention_forward(
            qry, key, val, attn_bias=None, op=fmha.ck.FwOp
        )

    # Sanity check: the kernel must not produce NaNs.
    out = run_fwd()
    assert not out.isnan().any(), "NaN in output!"

    # fp32 reference attention, computed in BHMK and transposed back to BMHK.
    q32 = qry.float().transpose(1, 2)
    k32 = key.float().transpose(1, 2)
    v32 = val.float().transpose(1, 2)
    probs = torch.softmax(
        torch.matmul(q32, k32.transpose(-2, -1)) / math.sqrt(K), dim=-1
    )
    ref = torch.matmul(probs, v32).transpose(1, 2)
    err = (out.float() - ref).abs()
    max_diff = err.max().item()
    mean_diff = err.mean().item()

    # Warmup (untimed).
    for _ in range(n_warmup):
        run_fwd()
    torch.cuda.synchronize()

    # Timed region, measured with CUDA events.
    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(n_iters):
        run_fwd()
    stop_evt.record()
    torch.cuda.synchronize()

    elapsed_ms = start_evt.elapsed_time(stop_evt) / n_iters
    # Two matmuls (QK^T and PV), 2 FLOPs per MAC each -> 4*B*H*M*M*K.
    tflops = (4 * B * H * M * M * K) / (elapsed_ms * 1e-3) / 1e12

    print(
        f" B={B:2d} H={H:2d} M={M:5d} K={K:3d} {str(dtype):20s} | "
        f"{elapsed_ms:8.2f} ms {tflops:6.1f} TFLOPS | "
        f"max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}"
    )
65+
66+
67+
def main():
    """Run the CK FMHA FA4 benchmark suite over a fixed set of shapes."""
    print("CK FMHA FA4 Benchmark (MI350x)")
    print("=" * 100)
    print(f" {'Shape':50s} | {'Time':8s} {'Perf':6s} | {'Accuracy':30s}")
    print("-" * 100)

    # (B, H, M, K, dtype) cases, in the same order as the original sweep.
    cases = [
        # Typical Flux2 shapes
        (1, 24, 4096, 128, torch.bfloat16),
        (1, 24, 4096, 128, torch.float16),
        # Various sequence lengths
        (4, 32, 1024, 128, torch.bfloat16),
        (4, 32, 2048, 128, torch.bfloat16),
        (4, 32, 4096, 128, torch.bfloat16),
        # Different head dims
        (2, 8, 2048, 64, torch.bfloat16),
        (2, 8, 2048, 128, torch.bfloat16),
        (2, 8, 2048, 256, torch.bfloat16),
        # Small batch, many heads (GPT-like)
        (1, 64, 2048, 64, torch.bfloat16),
    ]
    for B, H, M, K, dt in cases:
        bench(B=B, H=H, M=M, K=K, dtype=dt)

    print("=" * 100)
    print("All benchmarks passed correctness check!")


if __name__ == "__main__":
    main()
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# pyre-unsafe
7+
8+
"""
9+
Focused test for FA4 Eq. 6 conditional rescaling in CK FMHA.
10+
11+
Tests that the conditional rescaling optimization (threshold tau = log2(256) = 8.0)
12+
produces correct results by comparing against PyTorch reference attention.
13+
14+
The key insight: when the running max changes by <= tau between blocks, we can skip
15+
rescaling O_acc and instead correct P. This test exercises both branches:
16+
- Rescale branch: large max change between blocks (e.g., first block)
17+
- Skip branch: small max change between blocks (uniform-ish attention scores)
18+
"""
19+
20+
import math
21+
import unittest
22+
23+
import torch
24+
25+
26+
def ref_attention(q, k, v, scale=None):
    """Compute plain softmax attention as a PyTorch reference.

    Tensors use BMHK layout (batch, seqlen, heads, head_dim) and the
    result is returned in the same layout. ``scale`` defaults to
    ``1/sqrt(head_dim)``.
    """
    factor = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
    # Move to BHMK so the matmuls contract over seqlen / head_dim.
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))
    weights = torch.softmax(torch.matmul(qh, kh.transpose(-2, -1)) * factor, dim=-1)
    # (B, H, M, K) -> back to BMHK for the caller.
    return torch.matmul(weights, vh).transpose(1, 2)
41+
42+
43+
class CkFa4RescaleTest(unittest.TestCase):
    """Test CK FMHA forward pass correctness with FA4 conditional rescaling."""

    def _run_ck_fmha(self, q, k, v):
        """Run CK FMHA forward pass."""
        # Imported lazily so the test module can still be collected on
        # machines without the mslk extension installed.
        from mslk.attention import fmha

        out = fmha.memory_efficient_attention_forward(
            q, k, v, attn_bias=None, op=fmha.ck.FwOp
        )
        return out

    def _compare(self, q, k, v, atol, rtol, msg=""):
        """Compare CK FMHA output against reference."""
        # Reference runs in fp32 and is cast back to the input dtype so both
        # sides see the same output quantization before comparison.
        ref = ref_attention(q.float(), k.float(), v.float()).to(q.dtype)
        out = self._run_ck_fmha(q, k, v)

        self.assertFalse(out.isnan().any(), f"Output has NaNs {msg}")

        # Diffs are printed even on success so tolerance drift is visible
        # in the test log.
        max_diff = (out.float() - ref.float()).abs().max().item()
        mean_diff = (out.float() - ref.float()).abs().mean().item()
        print(f" {msg}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

        torch.testing.assert_close(
            out.float(),
            ref.float(),
            atol=atol,
            rtol=rtol,
            msg=lambda m: f"{msg}: {m}",
        )

    def test_uniform_scores_bf16(self):
        """Uniform Q/K: all blocks have similar max -> skip branch dominant."""
        torch.manual_seed(42)
        B, H, M, K = 2, 8, 1024, 128
        # Small uniform values -> max changes slowly between blocks
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.1
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.1
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        self._compare(q, k, v, atol=2e-2, rtol=1e-2, msg="uniform_bf16")

    def test_large_seqlen_bf16(self):
        """Long sequence: exercises many KV blocks, both branches."""
        torch.manual_seed(123)
        B, H, M, K = 1, 4, 4096, 64
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        # Slightly looser atol than the uniform case — accumulation error
        # grows with the number of KV blocks.
        self._compare(q, k, v, atol=3e-2, rtol=1e-2, msg="large_seqlen_bf16")

    def test_spike_pattern_bf16(self):
        """Spike pattern: one KV position has much larger dot product.

        Forces rescale branch on the block containing the spike,
        and skip branch on surrounding blocks.
        """
        torch.manual_seed(99)
        B, H, M, K = 2, 4, 2048, 128
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.1
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.1
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        # Insert spike at position 512 -> forces rescale when that block is hit
        k[:, 512, :, :] = 10.0
        self._compare(q, k, v, atol=3e-2, rtol=1e-2, msg="spike_bf16")

    def test_multiple_spikes_bf16(self):
        """Multiple spikes at different KV positions.

        Creates a pattern where some blocks trigger rescale and others don't,
        exercising the interleaving of both branches.
        """
        torch.manual_seed(77)
        B, H, M, K = 2, 4, 4096, 128
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.05
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16) * 0.05
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        # Spikes at various positions
        for pos in [256, 1024, 2048, 3072]:
            k[:, pos, :, :] = 8.0
        self._compare(q, k, v, atol=3e-2, rtol=1e-2, msg="multi_spike_bf16")

    def test_fp16(self):
        """FP16 precision test."""
        torch.manual_seed(42)
        B, H, M, K = 2, 8, 2048, 64
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
        self._compare(q, k, v, atol=2e-2, rtol=1e-2, msg="fp16")

    def test_deterministic(self):
        """Verify deterministic output across multiple runs."""
        torch.manual_seed(42)
        B, H, M, K = 1, 4, 2048, 128
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        # Bitwise equality required: the kernel must be run-to-run stable
        # on identical inputs, not merely numerically close.
        out1 = self._run_ck_fmha(q, k, v)
        out2 = self._run_ck_fmha(q, k, v)
        self.assertTrue(
            torch.equal(out1, out2),
            f"Non-deterministic: max_diff={((out1 - out2).abs().max().item())}",
        )

    def test_perf_benefit(self):
        """Measure performance of CK FMHA to verify no regression.

        The FA4 conditional rescaling should improve or maintain perf
        by skipping unnecessary exp2+multiply operations.

        NOTE(review): this prints timing only — it does not assert a
        threshold, so it cannot fail on a perf regression by itself.
        """
        torch.manual_seed(42)
        B, H, M, K = 4, 32, 4096, 128
        q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)

        # Warmup
        for _ in range(5):
            self._run_ck_fmha(q, k, v)
        torch.cuda.synchronize()

        # Benchmark
        n_iters = 20
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(n_iters):
            self._run_ck_fmha(q, k, v)
        end.record()
        torch.cuda.synchronize()

        elapsed_ms = start.elapsed_time(end) / n_iters
        # FLOPs: 2 * B * H * M * M * K (QK^T) + 2 * B * H * M * M * K (PV)
        flops = 4 * B * H * M * M * K
        tflops = flops / (elapsed_ms * 1e-3) / 1e12
        print(f"\n CK FMHA Perf: {elapsed_ms:.2f} ms/iter, {tflops:.1f} TFLOPS")
        print(f" Shape: B={B}, H={H}, M={M}, K={K}, dtype=bf16")

0 commit comments

Comments
 (0)