    register_x_val,
)

+# [AMD only] aiter backend
+HAS_AITER = True
+try:
+    import aiter_ops
+except (ImportError, IOError, AttributeError):
+    HAS_AITER = False
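+# Benchmarks below opt in via @register_benchmark(enabled=HAS_AITER), so they
+# are skipped automatically when aiter is not installed (e.g. on Nvidia builds).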
+

def parse_op_args(args: List[str]):
    parser = argparse.ArgumentParser()
@@ -83,6 +90,15 @@ def parse_op_args(args: List[str]):
from dataclasses import astuple, dataclass

+"""
+Runbook:
+- Nvidia:
+    buck2 run @mode/opt @mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100a -c fbcode.platform010_cuda_version=12.4 //pytorch/tritonbench:run -- --op decoding_attention --cudagraph --csv
+- AMD:
+    buck2 run @mode/opt-amd-gpu @mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.rocm_arch=mi300 //pytorch/tritonbench:run -- --op decoding_attention --cudagraph --csv
+"""
+
+
@dataclass
class _Shape:
    batch: int
@@ -153,6 +169,95 @@ def _pack_xformer_input(
    return q, k, v, attn_bias

+def get_dtype_max(dtype):
+    # Float dtypes report their max via finfo; integer dtypes raise TypeError,
+    # so fall back to iinfo.
+    try:
+        dtypeMax = torch.finfo(dtype).max
+    except TypeError:
+        dtypeMax = torch.iinfo(dtype).max
+    return dtypeMax
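+# For example (values per torch's finfo/iinfo): get_dtype_max(torch.int8) == 127,
+# and get_dtype_max(torch.float8_e4m3fnuz) should return that format's finite max (240.0).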
+
+
+def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch.int8):
+    if x_scale is None:
+        hidden_states = x
+    else:
+        # smooth quant: rescale activations before quantizing
+        hidden_states = x.to(x_scale) * x_scale
+    # per-token amax over the last dim -> [m, 1]
+    per_token_amax, _ = torch.max(input=torch.abs(hidden_states), dim=-1, keepdim=True)
+
+    dtypeMax = get_dtype_max(quant_dtype)
+
+    # symmetric per-token scale; map all-zero rows to scale 1 to avoid division by zero
+    per_token_scale = per_token_amax.to(dtype=torch.float32) / dtypeMax
+    per_token_scale[per_token_scale == 0] = 1
+
+    # quantize hidden_states
+    y = (hidden_states / per_token_scale).to(dtype=quant_dtype)
+    y_scale = per_token_scale.to(y_scale_dtype)
+    return y, y_scale
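+# Usage sketch (illustrative only, not part of the benchmark):
+#   x = torch.randn(4, 128)
+#   y, y_scale = pertoken_quant(x, quant_dtype=torch.int8)
+#   x_hat = y.to(torch.float32) * y_scale  # dequantize; the [m, 1] scale broadcasts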
+
+
+def pertoken_quant_kvcache_symm(
+    # [num_blocks, num_heads, head_size // x, block_size, x]
+    k_cache: torch.Tensor,
+    # [num_blocks, num_heads, head_size, block_size]
+    v_cache: torch.Tensor,
+    quant_dtype: torch.dtype,  # e.g. torch.float8_e4m3fnuz
+    scale_dtype: torch.dtype = torch.float32,
+) -> Tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+    num_blocks = k_cache.shape[0]
+    num_heads = k_cache.shape[1]
+    head_dim = v_cache.shape[2]
+    block_size = v_cache.shape[3]
+    total_tokens = num_blocks * block_size
+
+    # flatten both caches to [num_blocks, num_heads, block_size, head_dim]
+    # so each token row gets its own quantization scale
+    k_cache_permute = (
+        k_cache.permute(0, 1, 3, 2, 4)
+        .reshape(num_blocks, num_heads, block_size, -1)
+        .contiguous()
+    )
+    v_cache_permute = (
+        v_cache.permute(0, 1, 3, 2)
+        .reshape(num_blocks, num_heads, block_size, -1)
+        .contiguous()
+    )
+
+    k_quant, k_scale_asm = pertoken_quant(
+        k_cache_permute, scale_dtype, quant_dtype=quant_dtype
+    )
+    v_quant, v_scale_asm = pertoken_quant(
+        v_cache_permute, scale_dtype, quant_dtype=quant_dtype
+    )
+
+    # NOTE: quant_x and original x could be different
+    quant_x = 16 // quant_dtype.itemsize
+
+    # restore the paged-KV layouts the kernel expects
+    k_quant = (
+        k_quant.view(num_blocks, num_heads, block_size, head_dim // quant_x, quant_x)
+        .permute(0, 1, 3, 2, 4)
+        .contiguous()
+    )
+    k_scale = k_scale_asm.permute(1, 0, 2, 3).contiguous().view(num_heads, total_tokens)
+    v_quant = (
+        v_quant.view(num_blocks, num_heads, block_size, head_dim)
+        .permute(0, 1, 3, 2)
+        .contiguous()
+    )
+    v_scale = v_scale_asm.permute(1, 0, 2, 3).contiguous().view(num_heads, total_tokens)
+
+    return k_quant, k_scale, v_quant, v_scale, k_scale_asm, v_scale_asm
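+# Shape sketch: with (num_blocks, num_heads, head_dim, block_size) = (2, 8, 128, 16)
+# and an fp8 quant_dtype (itemsize 1, so quant_x = 16):
+#   k_quant [2, 8, 8, 16, 16], v_quant [2, 8, 128, 16],
+#   k_scale / v_scale [8, 32] (num_heads, total_tokens).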
+
+
+def asm_V_shuffle(VC):
+    # [num_blocks, num_kv_heads, head_size, block_size]
+    x = 16 // VC.element_size()
+    num_blocks, num_kv_heads, head_size, block_size = VC.shape
+    VC = VC.view(num_blocks, num_kv_heads, head_size, block_size // x, x)
+    # [num_blocks, num_kv_heads, block_size/x, head_size, x]
+    VC = VC.permute(0, 1, 3, 2, 4).contiguous()
+    return VC
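+# e.g. an fp8 V cache [2, 8, 128, 16] (element_size 1, so x = 16) becomes
+# [2, 8, 1, 128, 16]: 16-byte groups of the block dim moved innermost, the
+# layout the name suggests the ASM kernel expects.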
+
+
class Operator(BenchmarkOperator):
    DEFAULT_PRECISION = "bf16"
@@ -453,3 +558,74 @@ def fbgemm_gqa_fp8kv(
            use_tensor_cores=True,
            cache_logical_dtype_int=1,  # FP8 = 1
        )
+
+    @register_benchmark(enabled=HAS_AITER)
+    def aiter_paged_fp8kv(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        ori_dtype = q.dtype
+        dtype = torch.float8_e4m3fnuz
+
+        num_seqs = k_cache.shape[0]
+        max_seq_len = k_cache.shape[1]
+        block_size = 16
+        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+        num_blocks = max_num_blocks_per_seq * num_seqs
+        head_size = k_cache.shape[3]
+        num_heads = k_cache.shape[2]
+
+        # regenerate the caches in the paged [blocks, heads, head/x, block, x]
+        # layout; the incoming k_cache / v_cache are not reused directly
+        x = 16 // ori_dtype.itemsize
+        k_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+        v_cache_shape = (num_blocks, num_heads, head_size, block_size)
+        _k_cache = torch.rand(k_cache_shape, dtype=ori_dtype, device=self.device)
+        _v_cache = torch.rand(v_cache_shape, dtype=ori_dtype, device=self.device)
+
+        k_quant, k_scale, v_quant, v_scale, k_scale_asm, v_scale_asm = (
+            pertoken_quant_kvcache_symm(
+                _k_cache, _v_cache, dtype, scale_dtype=torch.float32
+            )
+        )
+
+        # total_tokens = num_blocks * block_size
+        # k_scale = torch.ones(
+        #     (num_heads, total_tokens), dtype=torch.float32, device=self.device
+        # )
+        # v_scale = torch.ones_like(k_scale)
+
+        # assign each sequence a disjoint, contiguous run of physical blocks
+        available_blocks = list(range(num_blocks))  # blocks 0 to num_blocks - 1
+        # available_blocks = [0] * num_blocks
+        block_tables_list = []
+        for _ in range(num_seqs):
+            block_tables = available_blocks[:max_num_blocks_per_seq]
+            available_blocks = available_blocks[max_num_blocks_per_seq:]
+            block_tables_list.append(block_tables)
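+        # e.g. num_seqs=2, max_num_blocks_per_seq=3 -> [[0, 1, 2], [3, 4, 5]]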
+
+        block_tables = torch.tensor(
+            block_tables_list, dtype=torch.int, device=self.device
+        )
+
+        num_query_heads = q.shape[2]
+        num_kv_heads = num_heads
+        uniform_range = (-1, 1)
+        # row stride matches a fused buffer holding (q_heads + 2 * kv_heads) heads
+        # per token, so query behaves like a non-contiguous slice of a packed QKV
+        query = torch.empty_strided(
+            (num_seqs, num_query_heads, head_size),
+            ((num_query_heads + 2 * num_kv_heads) * head_size, head_size, 1),
+            dtype=ori_dtype,
+            device=self.device,
+        )
+        query.uniform_(*uniform_range)
+
+        # pa_fwd_asm consumes the per-block-layout scales (k_scale_asm / v_scale_asm);
+        # the flattened k_scale / v_scale above are left unused here
+        return lambda: aiter_ops.pa_fwd_asm(
+            query.contiguous(),
+            k_quant,
+            asm_V_shuffle(v_quant),
+            block_tables,
+            cache_seqlens,
+            max_num_blocks_per_seq,
+            k_scale_asm,
+            v_scale_asm,
+        )
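+        # Usage sketch (hypothetical harness call, for orientation only):
+        #   fn = op.aiter_paged_fp8kv(q, k_cache, v_cache, cache_seqlens)
+        #   out = fn()  # tritonbench times this callable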