Skip to content

Commit 3d9bf6c

Browse files
committed
Enable Proton for python/tutorials/09-persistent-matmul.py
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent e8efee9 commit 3d9bf6c

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

python/tutorials/09-persistent-matmul.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ def torch_matmul(a, b):
650650
N, K = b.shape
651651
bytes_per_elem = a.element_size()
652652
flops_str = f"flops{bytes_per_elem * 8}"
653-
if is_cuda():
653+
if os.name != "nt":
654654
with proton.scope(f"torch [M={M}, N={N}, K={K}]",
655655
{"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
656656
c = torch.matmul(a, b.T)
@@ -672,8 +672,7 @@ def bench_fn(label, reps, warmup_reps, fn, *args):
672672
print(f"Benchmarking {label}: ...", end="")
673673
for _ in range(warmup_reps):
674674
fn(*args)
675-
#FIXME: Enable for XPU once proton support works.
676-
if is_cuda():
675+
if os.name != "nt":
677676
with proton_context():
678677
for _ in range(reps):
679678
fn(*args)
@@ -783,11 +782,11 @@ def show_profile(precision, profile_name):
783782

784783
validate(32, 32, 32, dtype)
785784
validate(8192, 8192, args.K_range[0], dtype)
786-
if is_cuda():
785+
if os.name != "nt":
787786
proton.start("matmul", hook="triton")
788787
proton.deactivate()
789788
for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
790789
bench(K, dtype)
791-
if is_cuda():
790+
if os.name != "nt":
792791
proton.finalize()
793792
show_profile(args.prec, "matmul")

0 commit comments

Comments (0)