@@ -227,28 +227,28 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
227227@benchmark_suit .perf_report (
228228 benchmark_suit .Benchmark (
229229 # argument names to use as an x-axis for the plot
230- x_names = ['B' , 'M' , 'K ' , 'N ' ],
230+ x_names = ['B' , 'M' , 'N ' , 'K ' ],
231231 # different possible values for `x_name`
232232 x_vals = [[1 , 1024 * i , 1024 * i , 1024 * i ] for i in [1 , 2 , 4 , 8 ]] + #
233233 [ #
234- [1 , 1 , 5120 , 13824 ], #
235- [1 , 4 , 4096 , 12288 ], #
234+ [1 , 1 , 13824 , 5120 ], #
235+ [1 , 4 , 12288 , 4096 ], #
236236 [1 , 512 , 8192 , 8192 ], #
237237 [1 , 512 , 8192 , 32768 ], #
238238 [1 , 512 , 32768 , 8192 ], #
239- [1 , 1024 , 16384 , 8192 ], #
240- [1 , 1024 , 28672 , 8192 ], #
241- [1 , 3072 , 4096 , 3072 ], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
242- [1 , 4096 , 16384 , 8192 ], #
243- [1 , 8192 , 16384 , 1024 ], #
244- [1 , 8192 , 16384 , 4096 ], #
239+ [1 , 1024 , 8192 , 16384 ], #
240+ [1 , 1024 , 8192 , 28672 ], #
241+ [1 , 3072 , 3072 , 4096 ], # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
242+ [1 , 4096 , 8192 , 16384 ], #
243+ [1 , 8192 , 1024 , 16384 ], #
244+ [1 , 8192 , 4096 , 16384 ], #
245245 [1 , 16384 , 1024 , 8192 ], #
246246 [1 , 16384 , 4096 , 8192 ], #
247247 [1 , 16384 , 8192 , 1024 ], #
248248 [1 , 16384 , 8192 , 4096 ], #
249249 [4 , 32768 , 128 , 4096 ], #
250250 [4 , 32768 , 4096 , 128 ], #
251- [32 , 4096 , 4096 , 128 ], #
251+ [32 , 4096 , 128 , 4096 ], #
252252 [4096 , 8 , 128 , 16384 ], #
253253 [4096 , 8 , 16384 , 128 ]
254254 ],
@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
268268def benchmark (B , M , N , K , provider ):
269269 a_shape , b_shape = get_shapes (B , M , N , K , transpose_a = TRANSPOSE_A , transpose_b = TRANSPOSE_B )
270270
271+ torch .manual_seed (0 )
271272 a = torch .rand (a_shape , device = 'xpu' , dtype = torch .bfloat16 )
272273 b = torch .rand (b_shape , device = 'xpu' , dtype = torch .bfloat16 )
273274
@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
291292 elif provider == 'triton' :
292293 assert len (a .shape ) == len (b .shape ), 'Incompatible sizes'
293294 if len (a .shape ) == 3 :
294- c = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
295+ c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
295296 else :
296297 assert len (a .shape ) == 2 , 'Expecting shape of length 2'
297- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
298+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
298299 triton_fn = lambda : matmul (a , b , c , transpose_a = TRANSPOSE_A , transpose_b = TRANSPOSE_B )
299300 torch_fn = lambda : torch .matmul (torch_a , torch_b ).to (torch .float32 )
300301 rtol = 1e-2 if a .dtype == torch .bfloat16 else 1e-3
@@ -304,17 +305,17 @@ def benchmark(B, M, N, K, provider):
304305 kernel_name = 'matmul_kernel_with_block_pointers' )
305306 elif provider == 'xetla' :
306307 if B == 1 :
307- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
308- acc = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
309- cnt = torch .empty ((M , N ), device = 'xpu' , dtype = torch .int32 )
308+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
309+ acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
310+ cnt = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .int32 )
310311 else :
311- c = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
312- acc = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313- cnt = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
312+ c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313+ acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
314+ cnt = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
314315 name = f'gemm_shape_{ B } _{ M } _{ K } _{ N } '
315316 # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316317 # better performance.
317- if (B , M , N , K ) == (1 , 3072 , 4096 , 3072 ):
318+ if (B , M , N , K ) == (1 , 3072 , 3072 , 4096 ):
318319 name = 'gemm_streamk_shape_3072_4096_3072'
319320 func = getattr (xetla_kernel , name )
320321 xetla_fn = lambda : func (a , b , c , acc , cnt )
0 commit comments