diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 3078e691b7..6a9a8baff9 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -650,7 +650,7 @@ def torch_matmul(a, b):
     N, K = b.shape
     bytes_per_elem = a.element_size()
     flops_str = f"flops{bytes_per_elem * 8}"
-    if is_cuda():
+    if os.name != "nt":
         with proton.scope(f"torch [M={M}, N={N}, K={K}]",
                           {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
             c = torch.matmul(a, b.T)
@@ -672,8 +672,7 @@ def bench_fn(label, reps, warmup_reps, fn, *args):
     print(f"Benchmarking {label}: ...", end="")
     for _ in range(warmup_reps):
         fn(*args)
-    #FIXME: Enable for XPU once proton support works.
-    if is_cuda():
+    if os.name != "nt":
         with proton_context():
             for _ in range(reps):
                 fn(*args)
@@ -783,11 +782,11 @@ def show_profile(precision, profile_name):
     validate(32, 32, 32, dtype)
     validate(8192, 8192, args.K_range[0], dtype)
 
-    if is_cuda():
+    if os.name != "nt":
         proton.start("matmul", hook="triton")
         proton.deactivate()
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
         bench(K, dtype)
-    if is_cuda():
+    if os.name != "nt":
         proton.finalize()
     show_profile(args.prec, "matmul")
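
The patch repeats the same `os.name != "nt"` guard at every proton call site, replacing the old `is_cuda()` check (and the XPU FIXME) so profiling runs on any non-Windows backend, CUDA or XPU alike, and is skipped only on Windows. A minimal sketch of that gating pattern, under the assumption that proton is unavailable on Windows; the helper names `supports_proton` and `profiled_run` are hypothetical, not part of the patch:

    import os

    import triton.profiler as proton


    def supports_proton():
        # Hypothetical helper mirroring the patch's guard: proton
        # profiling is assumed to work on any non-Windows platform
        # (CUDA or XPU), so only "nt" is excluded.
        return os.name != "nt"


    def profiled_run(fn, *args):
        # Sketch of how the benchmark gates its proton session: start
        # and finalize only when profiling is supported, but run the
        # workload unconditionally either way.
        if supports_proton():
            proton.start("matmul", hook="triton")
        fn(*args)
        if supports_proton():
            proton.finalize()

Factoring the check into one helper would avoid repeating the platform test at each call site, but the patch keeps the inline form, matching the structure of the `is_cuda()` calls it replaces.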