@@ -3,6 +3,7 @@
 import triton.language as tl
 
 import triton_kernels_benchmark as benchmark_suit
+import xetla_kernel
 
 if benchmark_suit.USE_IPEX_OPTION:
     import intel_extension_for_pytorch  # type: ignore  # noqa: F401
@@ -131,9 +132,9 @@ def forward(ctx, a, b, c, acc_dtype=None):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg`
-        line_vals=['triton'],
+        line_vals=['triton', 'xetla'],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=['Triton', 'XeTLA'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
@@ -148,23 +149,36 @@ def benchmark(M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
-                                                              quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              kernel_name='_kernel')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='_kernel')
+    elif provider == 'xetla':
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+
+        name = f'gemm_splitk_shape_{M}_{K}_{N}'
+        func = getattr(xetla_kernel, name)
+        xetla_fn = lambda: func(a, b, c, acc, cnt)
+        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
+
+        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='split_k_gemm_run')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
     tflops = lambda mean: 2 * M * N * K * (1e-12) / (mean * 1e-3)
     gbps = lambda mean: (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (mean * 1e-3)
 
-    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
 
 
 if __name__ == '__main__':
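
A note on the new `xetla` branch: it resolves a shape-specialized split-k kernel exported by the `xetla_kernel` extension module via `getattr`, so only shapes that were actually compiled into the module can run. Below is a minimal standalone sketch of that call pattern; the shape values are hypothetical, and the roles of `acc` and `cnt` (partial-sum and counter scratch for the split-k reduction) are inferred from their names rather than confirmed by the XeTLA sources shown here.

    import torch
    import xetla_kernel

    # Hypothetical split-k shape; must match a specialization built into xetla_kernel.
    M, N, K = 512, 512, 32768
    a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
    b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)

    c = torch.empty((M, N), device='xpu', dtype=torch.float32)    # final result
    acc = torch.empty((M, N), device='xpu', dtype=torch.float32)  # presumed partial-sum scratch
    cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)    # presumed split-k counter scratch

    # Raises AttributeError if no kernel was compiled for this shape.
    func = getattr(xetla_kernel, f'gemm_splitk_shape_{M}_{K}_{N}')
    func(a, b, c, acc, cnt)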
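The metric lambdas at the bottom of the hunk convert the mean latency in milliseconds into TFLOPS (one multiply-add counted as two flops) and GB/s (two bytes per bf16 input element, four per fp32 output element). A worked example with hypothetical numbers:

    M, N, K, mean_ms = 512, 512, 32768, 5.0  # hypothetical shape and latency
    tflops = 2 * M * N * K * 1e-12 / (mean_ms * 1e-3)
    gbps = (2 * (M * K + K * N) + 4.0 * (M * N)) * 1e-9 / (mean_ms * 1e-3)
    print(f'{tflops:.2f} TFLOPS, {gbps:.2f} GB/s')  # ~3.44 TFLOPS, ~13.63 GB/s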