Commit b3ddfca

[Tutorials] Rename reference library name (#2452)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 76426a7 commit b3ddfca

2 files changed: 11 additions, 4 deletions

python/tutorials/03-matrix-multiplication.py

Lines changed: 1 addition & 1 deletion

@@ -433,7 +433,7 @@ def matmul(a, b, activation=""):
 # We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices,
 # but feel free to arrange this script as you wish to benchmark any other matrix shape.
 
-ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'
+ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN' if is_xpu() else 'rocBLAS'
 
 configs = []
 for fp8_inputs in [False, True]:
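
For reference, this one-liner depends on two backend-detection helpers defined earlier in the tutorial. A minimal sketch of how they are expected to look is below: is_cuda() mirrors the helper added to 08-grouped-gemm.py in this same commit, while is_xpu() is an assumption that follows the same pattern (its definition is not part of this diff).

import triton


def is_cuda():
    # True when the active Triton driver targets an NVIDIA (CUDA) GPU.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_xpu():
    # Assumed helper, not shown in this diff: True when the active driver
    # targets an Intel XPU device.
    return triton.runtime.driver.active.get_current_target().backend == "xpu"


# Reference library selection, as in the changed line above:
# cuBLAS on NVIDIA, oneDNN on Intel XPU, rocBLAS otherwise (AMD).
ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN' if is_xpu() else 'rocBLAS'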

python/tutorials/08-grouped-gemm.py

Lines changed: 10 additions & 3 deletions

@@ -32,6 +32,10 @@
 import triton.language as tl
 
 
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
 @triton.autotune(
     configs=[
         triton.Config({
@@ -228,6 +232,9 @@ def torch_perf_fn(group_A, group_B):
         torch.matmul(a, b)
 
 
+ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN'
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         # argument names to use as an x-axis for the plot
@@ -236,9 +243,9 @@ def torch_perf_fn(group_A, group_B):
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
         # possible values for `line_arg``
-        line_vals=['cublas', 'triton'],
+        line_vals=[ref_lib.lower(), 'triton'],
         # label name for the lines
-        line_names=["cuBLAS", "Triton"],
+        line_names=[ref_lib, "Triton"],
         # line styles
         styles=[('green', '-'), ('blue', '-')],
         ylabel="runtime(ms)",  # label name for the y-axis
@@ -276,7 +283,7 @@ def benchmark(N, provider):
     d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")
 
     quantiles = [0.5, 0.2, 0.8]
-    if provider == 'cublas':
+    if provider == ref_lib.lower():
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(
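
Taken together, the benchmark now times whichever reference library the active backend provides. A minimal, self-contained sketch of the same measurement pattern, under the assumption that a CUDA or XPU device is available; the plain torch.matmul call stands in for the grouped loop in torch_perf_fn, and the sizes are illustrative:

import torch
import triton
import triton.testing


def is_cuda():
    # Same backend check that this commit adds to the tutorial.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN'
device = "cuda" if is_cuda() else "xpu"

# Illustrative operands; the tutorial benchmarks groups of matrices instead.
a = torch.randn((512, 512), device=device, dtype=torch.float16)
b = torch.randn((512, 512), device=device, dtype=torch.float16)

quantiles = [0.5, 0.2, 0.8]
# Time the reference path (cuBLAS on NVIDIA, oneDNN via torch.matmul on XPU).
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
print(f"{ref_lib}: median {ms:.4f} ms (20th pct {min_ms:.4f}, 80th pct {max_ms:.4f})")

As in the other Triton tutorials, the perf_report-decorated benchmark itself is typically driven with benchmark.run(show_plots=True, print_data=True), which sweeps the configured providers and problem sizes; that call is not part of this diff.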
