
Commit a3611b6

sijiac authored and facebook-github-bot committed
paged per-token benchmark
Summary: Add aiter paged kv to the decoding attention benchmark
Reviewed By: zjing14
Differential Revision: D69906734
fbshipit-source-id: 8b561a155a1d41d15fee916db999ed1d99a4c4b0
1 parent 7c39cf2 commit a3611b6


tritonbench/operators/decoding_attention/operator.py

Lines changed: 176 additions & 0 deletions
@@ -64,6 +64,13 @@
     register_x_val,
 )
 
+# [AMD only] aiter backend
+HAS_AITER = True
+try:
+    import aiter_ops
+except (ImportError, IOError, AttributeError):
+    HAS_AITER = False
+
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
@@ -83,6 +90,15 @@ def parse_op_args(args: List[str]):
 from dataclasses import astuple, dataclass
 
 
+"""
+- Runbook
+- Nvidia:
+buck2 run @mode/opt @mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12.4 //pytorch/tritonbench:run -- --op decoding_attention --cudagraph --csv
+- AMD:
+buck2 run @mode/opt-amd-gpu @mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.rocm_arch=mi300 //pytorch/tritonbench:run -- --op decoding_attention --cudagraph --csv
+"""
+
+
 @dataclass
 class _Shape:
     batch: int
@@ -153,6 +169,95 @@ def _pack_xformer_input(
     return q, k, v, attn_bias
 
 
+def get_dtype_max(dtype):
+    try:
+        dtypeMax = torch.finfo(dtype).max
+    except:
+        dtypeMax = torch.iinfo(dtype).max
+    return dtypeMax
+
+
+def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch.int8):
+    if x_scale is None:
+        hidden_states = x
+    else:
+        # smooth quant
+        hidden_states = x.to(x_scale) * x_scale
+    # [m, 1]
+    per_token_amax, _ = torch.max(input=torch.abs(hidden_states), dim=-1, keepdim=True)
+
+    dtypeMax = get_dtype_max(quant_dtype)
+
+    per_token_scale = per_token_amax.to(dtype=torch.float32) / dtypeMax
+    per_token_scale[per_token_scale == 0] = 1
+
+    # quant hidden_states
+    y = (hidden_states / per_token_scale).to(dtype=quant_dtype)
+    y_scale = per_token_scale.to(y_scale_dtype)
+    return y, y_scale
+
+
+def pertoken_quant_kvcache_symm(
+    # [num_blocks, num_heads, head_size // x, block_size, x]
+    k_cache: torch.Tensor,
+    # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    quant_dtype: torch.dtype,  # e.g. torch.float8_e4m3fnuz
+    scale_dtype: torch.dtype = torch.float32,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_blocks = k_cache.shape[0]
+    num_heads = k_cache.shape[1]
+    head_dim = v_cache.shape[2]
+    block_size = v_cache.shape[3]
+    total_tokens = num_blocks * block_size
+
+    k_cache_permute = (
+        k_cache.permute(0, 1, 3, 2, 4)
+        .reshape(num_blocks, num_heads, block_size, -1)
+        .contiguous()
+    )
+    v_cache_permute = (
+        v_cache.permute(0, 1, 3, 2)
+        .reshape(num_blocks, num_heads, block_size, -1)
+        .contiguous()
+    )
+
+    k_quant, k_scale_asm = pertoken_quant(
+        k_cache_permute, scale_dtype, quant_dtype=quant_dtype
+    )
+    v_quant, v_scale_asm = pertoken_quant(
+        v_cache_permute, scale_dtype, quant_dtype=quant_dtype
+    )
+
+    # NOTE: quant_x and original x could be different
+    quant_x = 16 // quant_dtype.itemsize
+
+    k_quant = (
+        k_quant.view(num_blocks, num_heads, block_size, head_dim // quant_x, quant_x)
+        .permute(0, 1, 3, 2, 4)
+        .contiguous()
+    )
+    k_scale = k_scale_asm.permute(1, 0, 2, 3).contiguous().view(num_heads, total_tokens)
+    v_quant = (
+        v_quant.view(num_blocks, num_heads, block_size, head_dim)
+        .permute(0, 1, 3, 2)
+        .contiguous()
+    )
+    v_scale = v_scale_asm.permute(1, 0, 2, 3).contiguous().view(num_heads, total_tokens)
+
+    return k_quant, k_scale, v_quant, v_scale, k_scale_asm, v_scale_asm
+
+
+def asm_V_shuffle(VC):
+    # [num_blocks, num_kv_heads, head_size, block_size]
+    x = 16 // VC.element_size()
+    num_blocks, num_kv_heads, head_size, block_size = VC.shape
+    VC = VC.view(num_blocks, num_kv_heads, head_size, block_size // x, x)
+    # [num_blocks, num_kv_heads, block_size/X, head_size, X]
+    VC = VC.permute(0, 1, 3, 2, 4).contiguous()
+    return VC
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_PRECISION = "bf16"
 
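For readers skimming the diff: pertoken_quant above computes one symmetric scale per token row (the row amax divided by the target dtype's max) and divides by it before casting. Below is a minimal standalone sketch of the same math in plain PyTorch, using int8 for readability; the explicit round() is added here for clarity, whereas the helper in the diff relies on the dtype cast, and the fp8 path uses torch.finfo instead of torch.iinfo.

import torch

# Per-token symmetric quantization, mirroring pertoken_quant above (illustrative only).
x = torch.randn(4, 8)                                   # [tokens, hidden]
dtype_max = torch.iinfo(torch.int8).max                 # 127; fp8 would use torch.finfo
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / dtype_max
per_token_scale[per_token_scale == 0] = 1               # guard all-zero rows, as in the diff

q = (x / per_token_scale).round().to(torch.int8)        # quantize
x_hat = q.float() * per_token_scale                     # dequantize

# round-trip error is bounded by half a quantization step per row
assert (x - x_hat).abs().max() <= per_token_scale.max() / 2 + 1e-6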
@@ -453,3 +558,74 @@ def fbgemm_gqa_fp8kv(
             use_tensor_cores=True,
             cache_logical_dtype_int=1,  # FP8 = 1
         )
+
+    @register_benchmark(enabled=HAS_AITER)
+    def aiter_paged_fp8kv(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        ori_dtype = q.dtype
+        dtype = torch.float8_e4m3fnuz
+
+        num_seqs = k_cache.shape[0]
+        max_seq_len = k_cache.shape[1]
+        block_size = 16
+        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+        num_blocks = max_num_blocks_per_seq * num_seqs
+        head_size = k_cache.shape[3]
+        num_heads = k_cache.shape[2]
+
+        x = 16 // ori_dtype.itemsize
+        k_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+        v_cache_shape = (num_blocks, num_heads, head_size, block_size)
+        _k_cache = torch.rand(k_cache_shape, dtype=ori_dtype, device=self.device)
+        _v_cache = torch.rand(v_cache_shape, dtype=ori_dtype, device=self.device)
+
+        k_quant, k_scale, v_quant, v_scale, k_scale_asm, v_scale_asm = (
+            pertoken_quant_kvcache_symm(
+                _k_cache, _v_cache, dtype, scale_dtype=torch.float32
+            )
+        )
+
+        # total_tokens = num_blocks * block_size
+        # k_scale = torch.ones(
+        #     (num_heads, total_tokens), dtype=torch.float32, device=self.device
+        # )
+        # v_scale = torch.ones_like(k_scale)
+
+        available_blocks = list(range(num_blocks))  # Blocks 0 to num_blocks-1
+        # available_blocks = [0] * num_blocks
+        block_tables_list = []
+        for _ in range(num_seqs):
+            block_tables = available_blocks[:max_num_blocks_per_seq]
+            available_blocks = available_blocks[max_num_blocks_per_seq:]
+            block_tables_list.append(block_tables)
+
+        block_tables = torch.tensor(
+            block_tables_list, dtype=torch.int, device=self.device
+        )
+
+        num_query_heads = q.shape[2]
+        num_kv_heads = num_heads
+        uniform_range = (-1, 1)
+        query = torch.empty_strided(
+            (num_seqs, num_query_heads, head_size),
+            ((num_query_heads + 2 * num_kv_heads) * head_size, head_size, 1),
+            dtype=ori_dtype,
+            device=self.device,
+        )
+        query.uniform_(*uniform_range)
+
+        return lambda: aiter_ops.pa_fwd_asm(
+            query.contiguous(),
+            k_quant,
+            asm_V_shuffle(v_quant),
+            block_tables,
+            cache_seqlens,
+            max_num_blocks_per_seq,
+            k_scale_asm,
+            v_scale_asm,
+        )
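The block table built in aiter_paged_fp8kv assigns each sequence a run of consecutive physical blocks of block_size tokens each. Below is a small illustrative sketch of the indexing this implies, independent of aiter; the locate helper and the concrete sizes are made up for this example and are not part of the benchmark.

import torch

# Paged-KV indexing implied by the block_tables construction above (illustrative only).
num_seqs, max_seq_len, block_size = 2, 40, 16
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size  # 3

# Sequence s owns consecutive block ids, exactly as in the benchmark's loop.
block_tables = torch.arange(
    num_seqs * max_num_blocks_per_seq, dtype=torch.int
).view(num_seqs, max_num_blocks_per_seq)

def locate(seq_idx: int, token_idx: int):
    # Map (sequence, token position) -> (physical block id, slot within the block).
    block = int(block_tables[seq_idx, token_idx // block_size])
    return block, token_idx % block_size

print(locate(0, 0))   # (0, 0): first token of seq 0 -> block 0, slot 0
print(locate(1, 17))  # (4, 1): token 17 of seq 1 -> its second block (id 4), slot 1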
