Skip to content

Commit efce869

Browse files
authored
Fix xetla import when built as wheel (#2589)
Fixes #2576.
1 parent 18f70d0 commit efce869

File tree

7 files changed

+18
-31
lines changed

7 files changed

+18
-31
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ repos:
118118
- --disable=line-too-long
119119
# Disable import-error: not everything can be imported when pre-commit runs
120120
- --disable=import-error
121+
- --disable=no-name-in-module
121122
# Disable unused-import: ruff has a corresponding check and supports "noqa: F401"
122123
- --disable=unused-import
123124
# Disable invalid_name: benchmarks use a lot of UPPER_SNAKE arguments

benchmarks/setup.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
# TODO: update once there is replacement for clean:
77
# https://github.com/pypa/setuptools/discussions/2838
88
from distutils import log # pylint: disable=[deprecated-module]
9-
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
10-
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]
119

1210
from setuptools import setup, Extension
1311
from setuptools.command.build_ext import build_ext as _build_ext
@@ -24,10 +22,10 @@ def __init__(self, name):
2422

2523
class CMakeBuild():
2624

27-
def __init__(self, debug=False, dry_run=False):
25+
def __init__(self, build_lib, build_temp, debug=False, dry_run=False):
2826
self.current_dir = os.path.abspath(os.path.dirname(__file__))
29-
self.build_temp = self.current_dir + "/build/temp"
30-
self.extdir = self.current_dir + "/triton_kernels_benchmark"
27+
self.build_temp = build_temp
28+
self.extdir = build_lib + "/triton_kernels_benchmark"
3129
self.build_type = self.get_build_type(debug)
3230
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
3331
self.use_ipex = False
@@ -101,30 +99,20 @@ def build_extension(self):
10199
self.check_call(["cmake"] + build_args)
102100
self.check_call(["cmake"] + install_args)
103101

104-
def clean(self):
105-
if os.path.exists(self.build_temp):
106-
remove_tree(self.build_temp, dry_run=self.dry_run)
107-
else:
108-
log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
109-
os.path.dirname(__file__)))
110-
111102

112103
class build_ext(_build_ext):
113104

114105
def run(self):
115-
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
106+
cmake = CMakeBuild(
107+
build_lib=self.build_lib,
108+
build_temp=self.build_temp,
109+
debug=self.debug,
110+
dry_run=self.dry_run,
111+
)
116112
cmake.run()
117113
super().run()
118114

119115

120-
class clean(_clean):
121-
122-
def run(self):
123-
cmake = CMakeBuild(dry_run=self.dry_run)
124-
cmake.clean()
125-
super().run()
126-
127-
128116
def get_git_commit_hash(length=8):
129117
try:
130118
cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
@@ -151,11 +139,10 @@ def get_git_commit_hash(length=8):
151139
package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
152140
cmdclass={
153141
"build_ext": build_ext,
154-
"clean": clean,
155142
},
156-
ext_modules=[CMakeExtension("triton_kernels_benchmark")],
157-
extra_require={
158-
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
143+
ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
144+
extras_require={
145+
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"],
159146
"pytorch": ["torch>=2.6"],
160147
},
161148
)

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import triton.language as tl
55

66
import triton_kernels_benchmark as benchmark_suit
7-
import xetla_kernel
7+
from triton_kernels_benchmark import xetla_kernel
88

99
if benchmark_suit.USE_IPEX_OPTION:
1010
import intel_extension_for_pytorch # type: ignore # noqa: F401

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from triton.runtime import driver
1414

1515
import triton_kernels_benchmark as benchmark_suit
16-
import xetla_kernel
16+
from triton_kernels_benchmark import xetla_kernel
1717

1818

1919
@torch.jit.script

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414

1515
import triton_kernels_benchmark as benchmark_suit
1616
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD
17-
18-
import xetla_kernel
17+
from triton_kernels_benchmark import xetla_kernel
1918

2019
if benchmark_suit.USE_IPEX_OPTION:
2120
import intel_extension_for_pytorch # type: ignore # noqa: F401

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import triton.language as tl
44

55
import triton_kernels_benchmark as benchmark_suit
6-
import xetla_kernel
6+
from triton_kernels_benchmark import xetla_kernel
77

88
if benchmark_suit.USE_IPEX_OPTION:
99
import intel_extension_for_pytorch # type: ignore # noqa: F401

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import triton.language as tl
1111

1212
import triton_kernels_benchmark as benchmark_suit
13-
import xetla_kernel
13+
from triton_kernels_benchmark import xetla_kernel
1414

1515
if benchmark_suit.USE_IPEX_OPTION:
1616
import intel_extension_for_pytorch # type: ignore # noqa: F401

0 commit comments

Comments
 (0)