Skip to content

Commit e4c05a5

Browse files
committed
Add ext_modules for install build
1 parent 81b7817 commit e4c05a5

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

benchmarks/setup.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,36 @@
33
import subprocess
44
import sys
55

6-
from distutils import log
7-
from distutils.dir_util import remove_tree
8-
from distutils.command.clean import clean as _clean
9-
from distutils.command.build import build as _build
10-
11-
from setuptools import setup
6+
from setuptools import setup, Extension
7+
from setuptools.command.build_ext import build_ext as _build_ext
8+
from setuptools._distutils import log
9+
from setuptools._distutils.dir_util import remove_tree
10+
from setuptools._distutils.command.clean import clean as _clean
1211

1312
import torch
1413

1514

15+
class CMakeExtension(Extension):
16+
17+
def __init__(self, name):
18+
# don't invoke the original build_ext for this special extension
19+
super().__init__(name, sources=[])
20+
21+
1622
class CMakeBuild():
1723

18-
def __init__(self, build_type="Debug"):
24+
def __init__(self, debug=False, dry_run=False):
1925
self.current_dir = os.path.abspath(os.path.dirname(__file__))
2026
self.build_temp = self.current_dir + "/build/temp"
2127
self.extdir = self.current_dir + "/triton_kernels_benchmark"
22-
self.build_type = build_type
28+
self.build_type = self.get_build_type(debug)
2329
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
2430
self.use_ipex = False
31+
self.dry_run = dry_run
32+
33+
def get_build_type(self, debug):
34+
DEBUG_OPTION = os.getenv("DEBUG", "0")
35+
return "Debug" if debug or (DEBUG_OPTION == "1") else "Release"
2536

2637
def run(self):
2738
self.check_ipex()
@@ -41,7 +52,8 @@ def check_ipex(self):
4152

4253
def check_call(self, *popenargs, **kwargs):
4354
print(" ".join(popenargs[0]))
44-
subprocess.check_call(*popenargs, **kwargs)
55+
if not self.dry_run:
56+
subprocess.check_call(*popenargs, **kwargs)
4557

4658
def build_extension(self):
4759
ninja_dir = shutil.which("ninja")
@@ -94,38 +106,27 @@ def clean(self):
94106
os.path.dirname(__file__)))
95107

96108

97-
class build(_build):
109+
class build_ext(_build_ext):
98110

99111
def run(self):
100-
self.build_cmake()
101-
super().run()
102-
103-
def build_cmake(self):
104-
DEBUG_OPTION = os.getenv("DEBUG", "0")
105-
debug = DEBUG_OPTION == "1"
106-
if hasattr(self, "debug"):
107-
debug = debug or self.debug
108-
build_type = "Debug" if debug else "Release"
109-
cmake = CMakeBuild(build_type)
112+
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
110113
cmake.run()
114+
super().run()
111115

112116

113117
class clean(_clean):
114118

115119
def run(self):
116-
self.clean_cmake()
117-
super().run()
118-
119-
def clean_cmake(self):
120-
cmake = CMakeBuild()
120+
cmake = CMakeBuild(dry_run=self.dry_run)
121121
cmake.clean()
122+
super().run()
122123

123124

124125
setup(name="triton-kernels-benchmark", packages=[
125126
"triton_kernels_benchmark",
126127
], package_dir={
127128
"triton_kernels_benchmark": "triton_kernels_benchmark",
128129
}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
129-
"build": build,
130+
"build_ext": build_ext,
130131
"clean": clean,
131-
})
132+
}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])

0 commit comments

Comments
 (0)