@@ -305,19 +305,28 @@ def benchmark(B, M, N, K, provider):
305305 elif provider == 'xetla' :
306306 if B == 1 :
307307 c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
308- acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
309308 cnt = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .int32 )
310309 else :
311310 c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
312- acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313311 cnt = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
314312 name = f'gemm_shape_{ B } _{ M } _{ K } _{ N } '
315313 # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
316314 # better performance.
317315 if (B , M , N , K ) == (1 , 3072 , 3072 , 4096 ):
318316 name = 'gemm_streamk_shape_3072_4096_3072'
319317 func = getattr (xetla_kernel , name )
320- xetla_fn = lambda : func (a , b , c , acc , cnt )
318+
319+
320+ def xetla_func_with_acc_allocation ():
321+ # allocating `acc` matrix on every function call, to be as similar as
322+ # possible to the triton kernel, which also does this on every call.
323+ if B == 1 :
324+ acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
325+ else :
326+ acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
327+ return func (a , b , c , acc , cnt )
328+
329+ xetla_fn = xetla_func_with_acc_allocation
321330 torch_fn = lambda : torch .matmul (a , b ).to (torch .float32 )
322331
323332 # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
0 commit comments