To compare the performance with the XeTLA kernel.

"""
+import os

import torch
import triton
if benchmark_suit.USE_IPEX_OPTION:
    import intel_extension_for_pytorch  # type: ignore  # noqa: F401

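+# Transposed layouts are opted into via environment variables (e.g. run with
+# TRANSPOSE_B=1 to benchmark A @ B^T); the XeTLA baseline is only measured for
+# the plain, non-transposed layout.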
+TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
+TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
+use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
+

@triton.autotune(
    configs=[
@@ -158,15 +163,22 @@ def matmul_kernel_with_block_pointers_batched(

# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) launches the above kernel.
-def matmul(a, b, c):
+def matmul(a, b, c, transpose_a=False, transpose_b=False):
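+    # -2 and -1 index the last two dims of a and b; swapping them when an operand
+    # is transposed lets the same block-pointer kernels read either layout through
+    # the strides passed below.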
+    a_major, a_minor = -2, -1
+    if transpose_a:
+        a_major, a_minor = a_minor, a_major
+    b_minor, b_major = -2, -1
+    if transpose_b:
+        b_major, b_minor = b_minor, b_major
+
+    assert a.shape[a_minor] == b.shape[b_minor], 'Incompatible dimensions'
+    assert a.is_contiguous(), 'Matrix A must be contiguous'
+    assert b.is_contiguous(), 'Matrix B must be contiguous'
+    M, N, K = a.shape[a_major], b.shape[b_major], a.shape[a_minor]
    # Check constraints.
    if len(a.shape) == 3 and len(b.shape) == 3:
        assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
-        assert a.shape[2] == b.shape[1], 'Incompatible dimensions'
-        assert a.is_contiguous(), 'Matrix A must be contiguous'
-        assert b.is_contiguous(), 'Matrix B must be contiguous'
-        B, M, K = a.shape
-        B, K, N = b.shape
+        B = a.shape[0]
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (
            B,
@@ -175,27 +187,37 @@ def matmul(a, b, c):
        matmul_kernel_with_block_pointers_batched[grid](
            a, b, c,  #
            B, M, N, K,  #
-            a.stride(0), a.stride(1), a.stride(2),  #
-            b.stride(0), b.stride(1), b.stride(2),  #
+            a.stride(0), a.stride(a_major), a.stride(a_minor),  #
+            b.stride(0), b.stride(b_minor), b.stride(b_major),  #
            c.stride(0), c.stride(1), c.stride(2))
    elif len(a.shape) == 2 and len(b.shape) == 2:
-        assert a.shape[1] == b.shape[0], 'Incompatible dimensions'
-        assert a.is_contiguous(), 'Matrix A must be contiguous'
-        assert b.is_contiguous(), 'Matrix B must be contiguous'
-        M, K = a.shape
-        K, N = b.shape
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
        matmul_kernel_with_block_pointers[grid](
            a, b, c,  #
            M, N, K,  #
-            a.stride(0), a.stride(1),  #
-            b.stride(0), b.stride(1),  #
+            a.stride(a_major), a.stride(a_minor),  #
+            b.stride(b_minor), b.stride(b_major),  #
            c.stride(0), c.stride(1))
    else:
        assert False, 'Input matrix dimensions mismatch'
    return c


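+# Builds the stored (optionally batched) shapes for A and B, swapping the last two
+# dims of an operand when it is transposed, e.g.
+# get_shapes(4, 8, 16, 32, transpose_a=False, transpose_b=True) == ((4, 8, 32), (4, 16, 32)).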
+def get_shapes(B, M, N, K, transpose_a, transpose_b):
+    a_shape = (M, K)
+    if transpose_a:
+        a_shape = (K, M)
+
+    b_shape = (K, N)
+    if transpose_b:
+        b_shape = (N, K)
+
+    if B != 1:
+        a_shape = (B, *a_shape)
+        b_shape = (B, *b_shape)
+    return a_shape, b_shape
+
+
# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
@@ -228,9 +250,9 @@ def matmul(a, b, c):
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg``
-        line_vals=['triton', 'xetla'],
+        line_vals=['triton'] + (['xetla'] if use_xetla else []),
        # label name for the lines
-        line_names=['Triton', 'XeTLA'],
+        line_names=['Triton'] + (['XeTLA'] if use_xetla else []),
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -239,27 +261,33 @@ def matmul(a, b, c):
        args={},
    ))
def benchmark(B, M, N, K, provider):
-    if B == 1:
-        a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-    else:
-        a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16)
-        b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16)
+    a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
+
+    a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
+    b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)

    quantiles = [0.5, 0.0, 1.0]

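+    # torch.transpose returns a view, so torch_a/torch_b present torch.matmul with
+    # the logical (.., M, K) x (.., K, N) operands matching the stored layouts.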
+    torch_a = a
+    if TRANSPOSE_A:
+        torch_a = torch.transpose(torch_a, -2, -1)
+
+    torch_b = b
+    if TRANSPOSE_B:
+        torch_b = torch.transpose(torch_b, -2, -1)
+
    if provider == 'onednn':
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                                 quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10,
+                                                                 rep=10, quantiles=quantiles)
    elif provider == 'triton':
        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
        if len(a.shape) == 3:
            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
        else:
            assert len(a.shape) == 2, 'Expecting shape of length 2'
            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        triton_fn = lambda: matmul(a, b, c)
-        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+        triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
+        torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
        rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,