33import subprocess
44import sys
55
6- from distutils import log
7- from distutils .dir_util import remove_tree
8- from distutils .command .clean import clean as _clean
9- from distutils .command .build import build as _build
10-
11- from setuptools import setup
6+ from setuptools import setup , Extension
7+ from setuptools .command .build_ext import build_ext as _build_ext
8+ from setuptools ._distutils import log
9+ from setuptools ._distutils .dir_util import remove_tree
10+ from setuptools ._distutils .command .clean import clean as _clean
1211
1312import torch
1413
1514
15+ class CMakeExtension (Extension ):
16+
17+ def __init__ (self , name ):
18+ # don't invoke the original build_ext for this special extension
19+ super ().__init__ (name , sources = [])
20+
21+
1622class CMakeBuild ():
1723
18- def __init__ (self , build_type = "Debug" ):
24+ def __init__ (self , debug = False , dry_run = False ):
1925 self .current_dir = os .path .abspath (os .path .dirname (__file__ ))
2026 self .build_temp = self .current_dir + "/build/temp"
2127 self .extdir = self .current_dir + "/triton_kernels_benchmark"
22- self .build_type = build_type
28+ self .build_type = self . get_build_type ( debug )
2329 self .cmake_prefix_paths = [torch .utils .cmake_prefix_path ]
2430 self .use_ipex = False
31+ self .dry_run = dry_run
32+
33+ def get_build_type (self , debug ):
34+ DEBUG_OPTION = os .getenv ("DEBUG" , "0" )
35+ return "Debug" if debug or (DEBUG_OPTION == "1" ) else "Release"
2536
2637 def run (self ):
2738 self .check_ipex ()
@@ -41,7 +52,8 @@ def check_ipex(self):
4152
4253 def check_call (self , * popenargs , ** kwargs ):
4354 print (" " .join (popenargs [0 ]))
44- subprocess .check_call (* popenargs , ** kwargs )
55+ if not self .dry_run :
56+ subprocess .check_call (* popenargs , ** kwargs )
4557
4658 def build_extension (self ):
4759 ninja_dir = shutil .which ("ninja" )
@@ -94,38 +106,27 @@ def clean(self):
94106 os .path .dirname (__file__ )))
95107
96108
97- class build ( _build ):
109+ class build_ext ( _build_ext ):
98110
99111 def run (self ):
100- self .build_cmake ()
101- super ().run ()
102-
103- def build_cmake (self ):
104- DEBUG_OPTION = os .getenv ("DEBUG" , "0" )
105- debug = DEBUG_OPTION == "1"
106- if hasattr (self , "debug" ):
107- debug = debug or self .debug
108- build_type = "Debug" if debug else "Release"
109- cmake = CMakeBuild (build_type )
112+ cmake = CMakeBuild (debug = self .debug , dry_run = self .dry_run )
110113 cmake .run ()
114+ super ().run ()
111115
112116
113117class clean (_clean ):
114118
115119 def run (self ):
116- self .clean_cmake ()
117- super ().run ()
118-
119- def clean_cmake (self ):
120- cmake = CMakeBuild ()
120+ cmake = CMakeBuild (dry_run = self .dry_run )
121121 cmake .clean ()
122+ super ().run ()
122123
123124
124125setup (name = "triton-kernels-benchmark" , packages = [
125126 "triton_kernels_benchmark" ,
126127], package_dir = {
127128 "triton_kernels_benchmark" : "triton_kernels_benchmark" ,
128129}, package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]}, cmdclass = {
129- "build " : build ,
130+ "build_ext " : build_ext ,
130131 "clean" : clean ,
131- })
132+ }, ext_modules = [ CMakeExtension ( "triton_kernels_benchmark" )] )
0 commit comments