 This benchmark is modified from gemm_benchmark.py to add a matrix to the output of the gemm operation.
 
 """
+import os
 
 import torch
 import triton
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
 
+INT8_ONLY_OPTION = os.getenv('INT8_ONLY', '0') == '1'
+ALL_DTYPES_OPTION = os.getenv('ALL_DTYPES', '0') == '1'
+
+
+def dtypes():
+    if ALL_DTYPES_OPTION:
+        return [torch.bfloat16, torch.int8]
+    if INT8_ONLY_OPTION:
+        return [torch.int8]
+    return [torch.bfloat16]
+
+
+def suffix():
+    if ALL_DTYPES_OPTION:
+        return 'all'
+    if INT8_ONLY_OPTION:
+        return 'int8'
+    return 'bfloat16'
+
 
 @triton.autotune(
     configs=[
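Note on the helpers added above: `INT8_ONLY` and `ALL_DTYPES` are read once at import time, so they must be set in the environment before the module is loaded; they drive both `dtypes()` (which input dtypes the sweep covers) and `suffix()` (the plot-name suffix). A minimal sketch of that mapping, using an illustrative `_selection` helper that is not part of the benchmark:

    import torch

    # Illustrative only: reproduces the flag -> (dtypes, suffix) mapping above.
    def _selection(env):
        if env.get('ALL_DTYPES', '0') == '1':
            return [torch.bfloat16, torch.int8], 'all'
        if env.get('INT8_ONLY', '0') == '1':
            return [torch.int8], 'int8'
        return [torch.bfloat16], 'bfloat16'

    print(_selection({}))                   # default: ([torch.bfloat16], 'bfloat16')
    print(_selection({'INT8_ONLY': '1'}))   # ([torch.int8], 'int8')
    print(_selection({'ALL_DTYPES': '1'}))  # ([torch.bfloat16, torch.int8], 'all')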
@@ -43,7 +63,8 @@ def matmul_kernel_with_block_pointers(
         stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
         stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
         stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
-        stride_dm: tl.constexpr, stride_dn: tl.constexpr,
+        stride_dm: tl.constexpr, stride_dn: tl.constexpr,  #
+        ACCUMULATOR_DTYPE: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
     pid = tl.program_id(axis=0)
@@ -63,7 +84,7 @@ def matmul_kernel_with_block_pointers(
                                     offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                     order=(1, 0))
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
     for _ in range(0, K, BLOCK_SIZE_K):
         a = tl.load(a_block_ptr, boundary_check=(0, 1))
         b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -117,7 +138,8 @@ def matmul_kernel_with_block_pointers_batched(
         stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
         stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
-        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,
+        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,  #
+        ACCUMULATOR_DTYPE: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
     bid = tl.program_id(axis=0)
@@ -141,7 +163,7 @@ def matmul_kernel_with_block_pointers_batched(
                                     offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                     order=(1, 0))
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
     for _ in range(0, K, BLOCK_SIZE_K):
         a = tl.load(a_block_ptr, boundary_check=(0, 1))
         b = tl.load(b_block_ptr, boundary_check=(0, 1))
@@ -185,7 +207,8 @@ def matmul(a, b, d, c):
             a.stride(0), a.stride(1), a.stride(2),  #
             b.stride(0), b.stride(1), b.stride(2),  #
             c.stride(0), c.stride(1), c.stride(2),  #
-            d.stride(0), d.stride(1), d.stride(2))
+            d.stride(0), d.stride(1), d.stride(2),  #
+            tl.float32 if a.dtype.is_floating_point else tl.int32)
     elif len(a.shape) == 2 and len(b.shape) == 2:
         assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
         assert a.is_contiguous(), 'Matrix A must be contiguous'
@@ -199,7 +222,8 @@ def matmul(a, b, d, c):
             a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
-            d.stride(0), d.stride(1))
+            d.stride(0), d.stride(1),  #
+            tl.float32 if a.dtype.is_floating_point else tl.int32)
     else:
         assert False, 'Input matrixs dimensions mismatch'
     return c
@@ -209,10 +233,10 @@ def matmul(a, b, d, c):
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'M', 'K', 'N'],
+        x_names=['B', 'M', 'K', 'N', 'dtype'],
         # different possible values for `x_name`
-        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
-        [  #
+        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] +  #
+        [[*shape, dtype] for shape in [  #
             [1, 1, 5120, 13824],  #
             [1, 4, 4096, 12288],  #
             [1, 512, 8192, 8192],  #
@@ -232,8 +256,8 @@ def matmul(a, b, d, c):
             [4, 32768, 4096, 128],  #
             [32, 4096, 4096, 128],  #
             [4096, 8, 128, 16384],  #
-            [4096, 8, 16384, 128]
-        ],
+            [4096, 8, 16384, 128]  #
+        ] for dtype in dtypes()],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
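For clarity, the reworked `x_vals` crosses every `[B, M, K, N]` shape with the selected dtypes, so each shape yields one benchmark row per dtype. A small self-contained illustration, using only the first two shapes from the list above:

    import torch

    shapes = [[1, 1, 5120, 13824], [1, 4, 4096, 12288]]
    dtypes = [torch.bfloat16, torch.int8]  # what dtypes() returns with ALL_DTYPES=1

    x_vals = [[*shape, dtype] for shape in shapes for dtype in dtypes]
    # [[1, 1, 5120, 13824, torch.bfloat16], [1, 1, 5120, 13824, torch.int8],
    #  [1, 4, 4096, 12288, torch.bfloat16], [1, 4, 4096, 12288, torch.int8]]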
@@ -243,33 +267,46 @@ def matmul(a, b, d, c):
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='matmul-performance-postop-addmatrix',
+        plot_name='matmul-performance-postop-addmatrix' + '-' + suffix(),
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, M, N, K, provider):
+def benchmark(B, M, N, K, dtype, provider):
+    res_dtype = torch.float32 if dtype.is_floating_point else torch.int32
+    if dtype.is_floating_point:
+        rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype)
+    else:
+        rand = lambda shape, dtype: torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
     if B == 1:
-        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-        d = torch.rand((M, N), device='xpu', dtype=torch.float32)
+        a = rand((M, K), dtype)
+        b = rand((K, N), dtype)
+        d = rand((M, N), res_dtype)
     else:
-        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
-        d = torch.rand((B, M, N), device='xpu', dtype=torch.float32)
+        a = rand((B, M, K), dtype)
+        b = rand((B, K, N), dtype)
+        d = rand((B, M, N), res_dtype)
 
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            c = torch.empty((B, M, N), device='xpu', dtype=res_dtype)
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            c = torch.empty((M, N), device='xpu', dtype=res_dtype)
         triton_fn = lambda: matmul(a, b, d, c)
-        torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
+        # Torch does not support integer calculation in matmul
+        torch_device = 'xpu' if dtype.is_floating_point else 'cpu'
+        torch_dtype = dtype if dtype.is_floating_point else res_dtype
+        torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype),
+                                        b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype
+                                                                                         ) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
+                                                       [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
+            # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
+            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:
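Taken together, the kernel now accumulates in `tl.float32` for floating-point inputs and `tl.int32` for int8, and the correctness check compares against a plain matmul-plus-matrix reference. A rough sketch of that reference on tiny CPU tensors, for illustration only (the benchmark itself allocates on 'xpu' and, as the comment in `torch_fn` notes, computes the int8 reference in int32):

    import torch

    M, K, N = 4, 8, 4  # tiny sizes, illustration only

    # bfloat16 inputs: reference accumulates/compares in float32
    a = torch.rand((M, K), dtype=torch.bfloat16)
    b = torch.rand((K, N), dtype=torch.bfloat16)
    d = torch.rand((M, N), dtype=torch.float32)
    ref_bf16 = torch.matmul(a, b).to(torch.float32) + d

    # int8 inputs: reference is computed in int32, mirroring the fallback in torch_fn
    a_i8 = torch.randint(-127, 128, (M, K), dtype=torch.int8)
    b_i8 = torch.randint(-127, 128, (K, N), dtype=torch.int8)
    d_i32 = torch.randint(-127, 128, (M, N), dtype=torch.int32)
    ref_int8 = torch.matmul(a_i8.to(torch.int32), b_i8.to(torch.int32)) + d_i32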