
Commit 74645a2

Enable 09-persistent-matmul.py on Win; don't import proton (#4489)
To fix (from https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15570525521/job/43845094182):

```bash
  File "C:\gh15570525521\python\tutorials\09-persistent-matmul.py", line 28, in <module>
    import triton.profiler as proton
  File "C:\ar\_work\intel-xpu-backend-for-triton\intel-xpu-backend-for-triton\.venv\lib\site-packages\triton\profiler\__init__.py", line 2, in <module>
    from .scope import scope, cpu_timed_scope, enter_scope, exit_scope
  File "C:\ar\_work\intel-xpu-backend-for-triton\intel-xpu-backend-for-triton\.venv\lib\site-packages\triton\profiler\scope.py", line 7, in <module>
    from triton._C.libproton import proton as libproton
ModuleNotFoundError: No module named 'triton._C.libproton'
```

* BMG CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15593957844 (failed)
* https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15607019612 (passed)

---------

Signed-off-by: Anatoly Myachev <[email protected]>
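For context, here is a minimal, self-contained sketch of the pattern this commit applies: Proton (`triton.profiler`) depends on the native `triton._C.libproton` extension, which is not built for Windows, so the import is guarded on `os.name`. The `proton = None` fallback and the `profiled_matmul` helper below are illustrative assumptions, not part of this change; the tutorial itself instead branches at the call site, as the second hunk of the diff shows.

```python
import os

import torch

# Guarded import: triton._C.libproton is not shipped on Windows, so importing
# triton.profiler there raises the ModuleNotFoundError shown above.
if os.name != "nt":
    import triton.profiler as proton
else:
    proton = None  # illustrative fallback; the commit itself simply skips the import


def profiled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: wrap the matmul in a proton scope only when the
    # profiler was actually imported, so the same code also runs on Windows.
    M, K = a.shape
    N, K = b.shape
    if proton is not None:
        with proton.scope(f"torch [M={M}, N={N}, K={K}]",
                          {"bytes": a.element_size() * (M * K + N * K + M * N)}):
            return torch.matmul(a, b.T)
    return torch.matmul(a, b.T)
```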
1 parent: 7bc086a

File tree

1 file changed: +9 −3 lines changed


python/tutorials/09-persistent-matmul.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -19,18 +19,21 @@
 Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090.
 """
 
+import os
 import argparse
 import itertools
 
 import torch
 import triton
 import triton.language as tl
-import triton.profiler as proton
 from triton.tools.tensor_descriptor import TensorDescriptor
 from contextlib import contextmanager
 
 from typing import Optional
 
+if os.name != "nt":
+    import triton.profiler as proton
+
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
 
@@ -625,8 +628,11 @@ def torch_matmul(a, b):
     N, K = b.shape
     bytes_per_elem = a.element_size()
     flops_str = f"flops{bytes_per_elem * 8}"
-    with proton.scope(f"torch [M={M}, N={N}, K={K}]",
-                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
+    if is_cuda():
+        with proton.scope(f"torch [M={M}, N={N}, K={K}]",
+                          {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
+            c = torch.matmul(a, b.T)
+    else:
         c = torch.matmul(a, b.T)
     return c
 
```
Comments (0)