@@ -306,19 +306,27 @@ def benchmark(B, M, N, K, provider):
306306 elif provider == 'xetla' :
307307 if B == 1 :
308308 c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
309- acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
310309 cnt = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .int32 )
311310 else :
312311 c = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
313- acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
314312 cnt = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .int32 )
315313 name = f'gemm_shape_{ B } _{ M } _{ K } _{ N } '
316314 # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
317315 # better performance.
318316 if (B , M , N , K ) == (1 , 3072 , 3072 , 4096 ):
319317 name = 'gemm_streamk_shape_3072_4096_3072'
320318 func = getattr (xetla_kernel , name )
321- xetla_fn = lambda : func (a , b , c , acc , cnt )
319+
320+ def xetla_func_with_acc_allocation ():
321+ # Allocate the `acc` matrix on every call so the timed work is as close as
322+ # possible to the Triton kernel, which also allocates it on every call.
323+ if B == 1 :
324+ acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
325+ else :
326+ acc = torch .zeros ((B , M , N ), device = 'xpu' , dtype = torch .float32 )
327+ return func (a , b , c , acc , cnt )
328+
329+ xetla_fn = xetla_func_with_acc_allocation
322330 torch_fn = lambda : torch .matmul (a , b ).to (torch .float32 )
323331
324332 kernels_name = {
0 commit comments