Commit 19ca5e6

Make the result consistent with FAv4 benchmark
Differential Revision: D80984604
Pull Request resolved: #356

1 parent: b824703

2 files changed: +32, -13 lines

tritonbench/operators/blackwell_attentions/generate_inputs.py

Lines changed: 11 additions & 7 deletions

@@ -14,7 +14,7 @@ def _generated_qkv_inputs(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     requires_grad = True
 
-    BATCH, H, N_CTX, N_CTX_KV, D_HEAD = shape
+    BATCH, H, N_HEADS_KV, N_CTX, N_CTX_KV, D_HEAD = shape
 
     q = torch.randn(
         (BATCH, H, N_CTX, D_HEAD),
@@ -23,13 +23,13 @@ def _generated_qkv_inputs(
         requires_grad=requires_grad,
     )
     k = torch.randn(
-        (BATCH, H, N_CTX_KV, D_HEAD),
+        (BATCH, N_HEADS_KV, N_CTX_KV, D_HEAD),
         dtype=dtype,
         device=device,
         requires_grad=requires_grad,
     )
     v = torch.randn(
-        (BATCH, H, N_CTX_KV, D_HEAD),
+        (BATCH, N_HEADS_KV, N_CTX_KV, D_HEAD),
         dtype=dtype,
         device=device,
         requires_grad=requires_grad,
@@ -42,27 +42,31 @@ def _generated_qkv_inputs(
 
 
 def customized_inputs(shape, num_inputs, dtype, device) -> Generator:
-    BATCH, H, SEQ_LEN, SEQ_LEN_KV, D_HEAD = shape
+    BATCH, H, N_HEADS_KV, SEQ_LEN, SEQ_LEN_KV, D_HEAD = shape
 
     SEQ_LEN_LOG2 = 7
 
     if SEQ_LEN is not None:
         SEQ_LEN_KV = SEQ_LEN if SEQ_LEN_KV is None else SEQ_LEN_KV
         if num_inputs is None:
             yield _generated_qkv_inputs(
-                (BATCH, H, SEQ_LEN, SEQ_LEN_KV, D_HEAD), dtype=dtype, device=device
+                (BATCH, H, N_HEADS_KV, SEQ_LEN, SEQ_LEN_KV, D_HEAD),
+                dtype=dtype,
+                device=device,
             )
         else:
             for _i in range(num_inputs):
                 yield _generated_qkv_inputs(
-                    (BATCH, H, SEQ_LEN, SEQ_LEN, D_HEAD), dtype=dtype, device=device
+                    (BATCH, H, N_HEADS_KV, SEQ_LEN, SEQ_LEN, D_HEAD),
+                    dtype=dtype,
+                    device=device,
                 )
                 SEQ_LEN *= 2
             return
     for i in range(SEQ_LEN_LOG2, 15):
         SEQ_LEN = 2**i
         yield _generated_qkv_inputs(
-            (BATCH, H, SEQ_LEN, SEQ_LEN, D_HEAD), dtype=dtype, device=device
+            (BATCH, H, H, SEQ_LEN, SEQ_LEN, D_HEAD), dtype=dtype, device=device
         )
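
Note: the shape tuple consumed by these helpers now carries a separate KV head count, so K and V can be generated with fewer heads than Q (the grouped-query case). Below is a minimal sketch of the resulting tensor layout, assuming the new (BATCH, H, N_HEADS_KV, N_CTX, N_CTX_KV, D_HEAD) convention from the hunks above; the concrete sizes, dtype, and device are illustrative only.

import torch

# Illustrative GQA configuration: 8 query heads sharing 2 KV heads.
BATCH, H, N_HEADS_KV, N_CTX, N_CTX_KV, D_HEAD = 4, 8, 2, 1024, 1024, 64
dtype, device = torch.bfloat16, "cuda"

# Q keeps the full head count; K and V use the (possibly smaller) KV head
# count, mirroring the torch.randn calls in _generated_qkv_inputs above.
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((BATCH, N_HEADS_KV, N_CTX_KV, D_HEAD), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((BATCH, N_HEADS_KV, N_CTX_KV, D_HEAD), dtype=dtype, device=device, requires_grad=True)

assert q.shape == (BATCH, H, N_CTX, D_HEAD)
assert k.shape == v.shape == (BATCH, N_HEADS_KV, N_CTX_KV, D_HEAD)

In the sequence-length sweep branch the KV head count is passed as H itself, so the default sweep still produces standard multi-head (non-GQA) inputs.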

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 21 additions & 6 deletions

@@ -47,7 +47,7 @@
 
 # [Optional] CuTe
 try:
-    import flash_attn.cute.interface as facute
+    from flash_attn.cute.interface import flash_attn_func as facute_flash_attn_func
 
     HAS_FLASH_CUTE = True
 except (ImportError, IOError, AttributeError):
@@ -98,6 +98,9 @@ def parse_op_args(args: List[str]):
         "--seq-len-kv", type=int, default=None, help="Sequence length kv"
     )
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
+    parser.add_argument(
+        "--n-heads-kv", type=int, default=None, help="Number of heads kv"
+    )
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument(
         "--causal",
@@ -136,6 +139,9 @@ def __init__(
         self.SEQ_LEN_KV = (
             args.seq_len_kv if args.seq_len_kv is not None else args.seq_len
         )
+        self.N_HEAD_KV = (
+            args.n_heads_kv if args.n_heads_kv is not None else args.n_heads
+        )
         self.H = args.n_heads
         self.D_HEAD = args.d_head
         self.causal = args.causal
@@ -288,7 +294,9 @@ def cutedsl_blackwell(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        return lambda: facute.flash_attn_func(q, k, v, self.sm_scale, self.causal)
+        return lambda: facute_flash_attn_func(
+            q, k, v, softmax_scale=self.sm_scale, causal=self.causal
+        )
 
     @register_benchmark()
     def flex_attention(self, q, k, v):
@@ -372,7 +380,14 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
     def get_input_iter(self) -> Generator:
         if self.input_types == "CUSTOMIZED_SHAPES":
             return customized_inputs(
-                shape=(self.BATCH, self.H, self.SEQ_LEN, self.SEQ_LEN_KV, self.D_HEAD),
+                shape=(
+                    self.BATCH,
+                    self.H,
+                    self.N_HEAD_KV,
+                    self.SEQ_LEN,
+                    self.SEQ_LEN_KV,
+                    self.D_HEAD,
+                ),
                 num_inputs=self.tb_args.num_inputs,
                 dtype=self.dtype,
                 device=self.device,
@@ -386,9 +401,9 @@ def get_input_iter(self) -> Generator:
         else:
             raise AssertionError(f"Unknown input type {self.input_types}")
 
-    @register_x_val(label="(Batch, Heads, SeqLen, SeqLen_KV, Dhead)")
+    @register_x_val(label="(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)")
     def get_x_val(self, example_inputs) -> float:
         q, k, v = example_inputs
         B, H, S, D = q.shape
-        _, _, S_KV, _ = k.shape
-        return (B, H, S, S_KV, D)
+        _, H_KV, S_KV, _ = k.shape
+        return (B, H, H_KV, S, S_KV, D)
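
Note: a minimal sketch of how the cutedsl_blackwell benchmark drives the renamed CuTe entry point after this change, assuming the (batch, heads, seq, head_dim) inputs produced by generate_inputs.py. The sizes, dtype, and device are illustrative, and the keyword names simply mirror the hunk above; adjust if your flash_attn build exposes a different signature.

import torch
from flash_attn.cute.interface import flash_attn_func as facute_flash_attn_func

B, H, H_KV, S, S_KV, D = 4, 8, 2, 1024, 1024, 64
sm_scale = 1.0 / (D**0.5)
dtype, device = torch.bfloat16, "cuda"

# Benchmark inputs arrive as (batch, heads, seq, head_dim) ...
q = torch.randn((B, H, S, D), dtype=dtype, device=device)
k = torch.randn((B, H_KV, S_KV, D), dtype=dtype, device=device)
v = torch.randn((B, H_KV, S_KV, D), dtype=dtype, device=device)

# ... and are transposed to (batch, seq, heads, head_dim) before the call,
# as cutedsl_blackwell does in the hunk above.
q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
out = facute_flash_attn_func(q, k, v, softmax_scale=sm_scale, causal=False)

With the new flag, a run can request fewer KV heads than query heads; when --n-heads-kv is omitted, N_HEAD_KV falls back to --n-heads, so existing configurations keep their previous behavior.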
