55import sys
66
77from setuptools import setup
8+ from setuptools .command .build_ext import build_ext as build_ext_orig
89
910import torch
1011
1718
1819class CMakeBuild ():
1920
20- def __init__ (self ):
21+ def __init__ (self , build_type = "Debug" ):
2122 self .current_dir = os .path .abspath (os .path .dirname (__file__ ))
2223 self .build_temp = self .current_dir + "/build/temp"
2324 self .extdir = self .current_dir + "/triton_kernels_benchmark"
25+ self .build_type = build_type
2426
2527 def run (self ):
2628 try :
@@ -36,9 +38,6 @@ def run(self):
3638 self .build_extension ()
3739
3840 def build_extension (self ):
39- # configuration
40- build_type = "Debug"
41-
4241 ninja_dir = shutil .which ("ninja" )
4342 # create build directories
4443 if not os .path .exists (self .build_temp ):
@@ -55,7 +54,7 @@ def build_extension(self):
5554 "-DCMAKE_VERBOSE_MAKEFILE=TRUE" ,
5655 "-DCMAKE_C_COMPILER=icx" ,
5756 "-DCMAKE_CXX_COMPILER=icpx" ,
58- "-DCMAKE_BUILD_TYPE=" + build_type ,
57+ "-DCMAKE_BUILD_TYPE=" + self . build_type ,
5958 "-S" ,
6059 self .current_dir ,
6160 "-B" ,
@@ -85,11 +84,23 @@ def build_extension(self):
8584 subprocess .check_call (["cmake" ] + install_args )
8685
8786
88- cmake = CMakeBuild ()
89- cmake .run ()
87+ class build_ext (build_ext_orig ):
88+
89+ def run (self ):
90+ self .build_cmake ()
91+ super ().run ()
92+
93+ def build_cmake (self ):
94+ DEBUG_OPTION = os .getenv ("DEBUG" , "0" )
95+ build_type = "Debug" if self .debug or DEBUG_OPTION == "1" else "Release"
96+ cmake = CMakeBuild (build_type )
97+ cmake .run ()
98+
9099
91100setup (name = "triton-kernels-benchmark" , packages = [
92101 "triton_kernels_benchmark" ,
93102], package_dir = {
94103 "triton_kernels_benchmark" : "triton_kernels_benchmark" ,
95- }, package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]})
104+ }, package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]}, cmdclass = {
105+ "build_ext" : build_ext ,
106+ })
0 commit comments