@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
268268def benchmark (B , M , N , K , provider ):
269269 a_shape , b_shape = get_shapes (B , M , N , K , transpose_a = TRANSPOSE_A , transpose_b = TRANSPOSE_B )
270270
271+ torch .manual_seed (0 )
271272 a = torch .rand (a_shape , device = 'xpu' , dtype = torch .bfloat16 )
272273 b = torch .rand (b_shape , device = 'xpu' , dtype = torch .bfloat16 )
273274
@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
291292 elif provider == 'triton' :
292293 assert len (a .shape ) == len (b .shape ), 'Incompatible sizes'
293294 if len (a .shape ) == 3 :
294- c = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
295+ c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
295296 else :
296297 assert len (a .shape ) == 2 , 'Expecting shape of length 2'
297- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
298+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
298299 triton_fn = lambda : matmul (a , b , c , transpose_a = TRANSPOSE_A , transpose_b = TRANSPOSE_B )
299300 torch_fn = lambda : torch .matmul (torch_a , torch_b ).to (torch .float32 )
300301 rtol = 1e-2 if a .dtype == torch .bfloat16 else 1e-3
@@ -304,13 +305,13 @@ def benchmark(B, M, N, K, provider):
304305 kernel_name = 'matmul_kernel_with_block_pointers' )
305306 elif provider == 'xetla' :
306307 if B == 1 :
307- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
308- acc = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
309- cnt = torch .empty ((M , N ), device = 'xpu' , dtype = torch .int32 )
308+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
309+ acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
310+ cnt = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .int32 )
310311 else :
311- c = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
312- acc = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313- cnt = torch .empty ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
312+ c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313+ acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
314+ cnt = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
314315 name = f'gemm_shape_{ B } _{ M } _{ K } _{ N } '
315316 # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316317 # better performance.
0 commit comments