@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
128128 [512 , 32768 , 8192 ],
129129 [1024 , 28672 , 8192 ],
130130 [3072 , 4096 , 3072 ],
131+ [4096 , 4096 , 4096 ],
131132 ],
132133 line_arg = 'provider' ,
133134 # argument name whose value corresponds to a different line in the plot
@@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
152153 _ , min_ms , max_ms , mean_ms , cv = benchmark_suit .do_bench (lambda : torch .matmul (a , b ), n_warmup = 10 , n_repeat = 10 ,
153154 quantiles = quantiles )
154155 elif provider == 'triton' :
155- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
156+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
156157 triton_fn = lambda : matmul (a , b , c )
157158 torch_fn = lambda : torch .matmul (a , b ).to (torch .float32 )
158159 rtol = 1e-2 if a .dtype == torch .bfloat16 else 1e-3
159160 benchmark_suit .assert_close (triton_fn (), torch_fn (), atol = 1e-4 , rtol = rtol , err_msg = 'triton to torch' )
160161 _ , min_ms , max_ms , mean_ms , cv = benchmark_suit .do_bench (triton_fn , n_warmup = 10 , n_repeat = 10 ,
161162 quantiles = quantiles , kernel_name = '_kernel' )
162163 elif provider == 'xetla' :
163- c = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
164- acc = torch .empty ((M , N ), device = 'xpu' , dtype = torch .float32 )
165- cnt = torch .empty ((M , N ), device = 'xpu' , dtype = torch .int32 )
164+ c = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
165+ acc = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .float32 )
166+ cnt = torch .zeros ((M , N ), device = 'xpu' , dtype = torch .int32 )
166167
167168 name = f'gemm_splitk_shape_{ M } _{ K } _{ N } '
168169 func = getattr (xetla_kernel , name )
0 commit comments