3232import triton .language as tl
3333
3434
def is_cuda():
    """Return True when the active Triton driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
38+
3539@triton .autotune (
3640 configs = [
3741 triton .Config ({
@@ -228,6 +232,9 @@ def torch_perf_fn(group_A, group_B):
228232 torch .matmul (a , b )
229233
230234
# Label for the reference vendor library: cuBLAS on NVIDIA GPUs, oneDNN otherwise.
if is_cuda():
    ref_lib = 'cuBLAS'
else:
    ref_lib = 'oneDNN'
237+
231238@triton .testing .perf_report (
232239 triton .testing .Benchmark (
233240 # argument names to use as an x-axis for the plot
@@ -236,9 +243,9 @@ def torch_perf_fn(group_A, group_B):
236243 line_arg = 'provider' ,
237244 # argument name whose value corresponds to a different line in the plot
238245 # possible values for `line_arg``
239- line_vals = ['cublas' , 'triton' ],
246+ line_vals = [ref_lib . lower () , 'triton' ],
240247 # label name for the lines
241- line_names = ["cuBLAS" , "Triton" ],
248+ line_names = [ref_lib , "Triton" ],
242249 # line styles
243250 styles = [('green' , '-' ), ('blue' , '-' )],
244251 ylabel = "runtime(ms)" , # label name for the y-axis
@@ -276,7 +283,7 @@ def benchmark(N, provider):
276283 d_g_lds = torch .tensor (g_lds , dtype = torch .int32 , device = "xpu" )
277284
278285 quantiles = [0.5 , 0.2 , 0.8 ]
279- if provider == 'cublas' :
286+ if provider == ref_lib . lower () :
280287 ms , min_ms , max_ms = triton .testing .do_bench (lambda : torch_perf_fn (group_A , group_B ), quantiles = quantiles )
281288 if provider == 'triton' :
282289 ms , min_ms , max_ms = triton .testing .do_bench (
0 commit comments