Skip to content

Commit 45cea57

Browse files
committed
Add ext_modules for install build
1 parent 432405b commit 45cea57

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

benchmarks/setup.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,39 @@
33
import subprocess
44
import sys
55

6-
from distutils import log
7-
from distutils.dir_util import remove_tree
8-
from distutils.command.clean import clean as _clean
9-
from distutils.command.build import build as _build
6+
# TODO: update once there is replacement for clean:
7+
# https://github.com/pypa/setuptools/discussions/2838
8+
from distutils import log # pylint: disable=[deprecated-module]
9+
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
10+
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]
1011

11-
from setuptools import setup
12+
from setuptools import setup, Extension
13+
from setuptools.command.build_ext import build_ext as _build_ext
1214

1315
import torch
1416

1517

18+
class CMakeExtension(Extension):
19+
20+
def __init__(self, name):
21+
# don't invoke the original build_ext for this special extension
22+
super().__init__(name, sources=[])
23+
24+
1625
class CMakeBuild():
1726

18-
def __init__(self, build_type="Debug"):
27+
def __init__(self, debug=False, dry_run=False):
1928
self.current_dir = os.path.abspath(os.path.dirname(__file__))
2029
self.build_temp = self.current_dir + "/build/temp"
2130
self.extdir = self.current_dir + "/triton_kernels_benchmark"
22-
self.build_type = build_type
31+
self.build_type = self.get_build_type(debug)
2332
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
2433
self.use_ipex = False
34+
self.dry_run = dry_run
35+
36+
def get_build_type(self, debug):
37+
DEBUG_OPTION = os.getenv("DEBUG", "0")
38+
return "Debug" if debug or (DEBUG_OPTION == "1") else "Release"
2539

2640
def run(self):
2741
self.check_ipex()
@@ -41,7 +55,8 @@ def check_ipex(self):
4155

4256
def check_call(self, *popenargs, **kwargs):
4357
print(" ".join(popenargs[0]))
44-
subprocess.check_call(*popenargs, **kwargs)
58+
if not self.dry_run:
59+
subprocess.check_call(*popenargs, **kwargs)
4560

4661
def build_extension(self):
4762
ninja_dir = shutil.which("ninja")
@@ -94,38 +109,27 @@ def clean(self):
94109
os.path.dirname(__file__)))
95110

96111

97-
class build(_build):
112+
class build_ext(_build_ext):
98113

99114
def run(self):
100-
self.build_cmake()
101-
super().run()
102-
103-
def build_cmake(self):
104-
DEBUG_OPTION = os.getenv("DEBUG", "0")
105-
debug = DEBUG_OPTION == "1"
106-
if hasattr(self, "debug"):
107-
debug = debug or self.debug
108-
build_type = "Debug" if debug else "Release"
109-
cmake = CMakeBuild(build_type)
115+
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
110116
cmake.run()
117+
super().run()
111118

112119

113120
class clean(_clean):
114121

115122
def run(self):
116-
self.clean_cmake()
117-
super().run()
118-
119-
def clean_cmake(self):
120-
cmake = CMakeBuild()
123+
cmake = CMakeBuild(dry_run=self.dry_run)
121124
cmake.clean()
125+
super().run()
122126

123127

124128
setup(name="triton-kernels-benchmark", packages=[
125129
"triton_kernels_benchmark",
126130
], package_dir={
127131
"triton_kernels_benchmark": "triton_kernels_benchmark",
128132
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
129-
"build": build,
133+
"build_ext": build_ext,
130134
"clean": clean,
131-
})
135+
}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])

0 commit comments

Comments
 (0)