Skip to content

Commit 6b2f1c7

Browse files
committed
Add ext_modules for install build
1 parent 81b7817 commit 6b2f1c7

File tree

1 file changed

+28
-26
lines changed

1 file changed

+28
-26
lines changed

benchmarks/setup.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,37 @@
33
import subprocess
44
import sys
55

6-
from distutils import log
7-
from distutils.dir_util import remove_tree
8-
from distutils.command.clean import clean as _clean
9-
from distutils.command.build import build as _build
6+
from distutils import log # noqa: W4901 # setuptools use it
7+
from distutils.dir_util import remove_tree # noqa: W4901 # setuptools use it
8+
from distutils.command.clean import clean as _clean # noqa: W4901 # setuptools use it
109

11-
from setuptools import setup
10+
from setuptools import setup, Extension
11+
from setuptools.command.build_ext import build_ext as _build_ext
1212

1313
import torch
1414

1515

16+
class CMakeExtension(Extension):
17+
18+
def __init__(self, name):
19+
# don't invoke the original build_ext for this special extension
20+
super().__init__(name, sources=[])
21+
22+
1623
class CMakeBuild():
1724

18-
def __init__(self, build_type="Debug"):
25+
def __init__(self, debug=False, dry_run=False):
1926
self.current_dir = os.path.abspath(os.path.dirname(__file__))
2027
self.build_temp = self.current_dir + "/build/temp"
2128
self.extdir = self.current_dir + "/triton_kernels_benchmark"
22-
self.build_type = build_type
29+
self.build_type = self.get_build_type(debug)
2330
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
2431
self.use_ipex = False
32+
self.dry_run = dry_run
33+
34+
def get_build_type(self, debug):
35+
DEBUG_OPTION = os.getenv("DEBUG", "0")
36+
return "Debug" if debug or (DEBUG_OPTION == "1") else "Release"
2537

2638
def run(self):
2739
self.check_ipex()
@@ -41,7 +53,8 @@ def check_ipex(self):
4153

4254
def check_call(self, *popenargs, **kwargs):
4355
print(" ".join(popenargs[0]))
44-
subprocess.check_call(*popenargs, **kwargs)
56+
if not self.dry_run:
57+
subprocess.check_call(*popenargs, **kwargs)
4558

4659
def build_extension(self):
4760
ninja_dir = shutil.which("ninja")
@@ -94,38 +107,27 @@ def clean(self):
94107
os.path.dirname(__file__)))
95108

96109

97-
class build(_build):
110+
class build_ext(_build_ext):
98111

99112
def run(self):
100-
self.build_cmake()
101-
super().run()
102-
103-
def build_cmake(self):
104-
DEBUG_OPTION = os.getenv("DEBUG", "0")
105-
debug = DEBUG_OPTION == "1"
106-
if hasattr(self, "debug"):
107-
debug = debug or self.debug
108-
build_type = "Debug" if debug else "Release"
109-
cmake = CMakeBuild(build_type)
113+
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
110114
cmake.run()
115+
super().run()
111116

112117

113118
class clean(_clean):
114119

115120
def run(self):
116-
self.clean_cmake()
117-
super().run()
118-
119-
def clean_cmake(self):
120-
cmake = CMakeBuild()
121+
cmake = CMakeBuild(dry_run=self.dry_run)
121122
cmake.clean()
123+
super().run()
122124

123125

124126
setup(name="triton-kernels-benchmark", packages=[
125127
"triton_kernels_benchmark",
126128
], package_dir={
127129
"triton_kernels_benchmark": "triton_kernels_benchmark",
128130
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
129-
"build": build,
131+
"build_ext": build_ext,
130132
"clean": clean,
131-
})
133+
}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])

0 commit comments

Comments
 (0)