33import subprocess
44import sys
55
6- from distutils import log
7- from distutils .dir_util import remove_tree
8- from distutils .command .clean import clean as _clean
9- from distutils .command .build import build as _build
6+ from distutils import log # noqa: W4901 # setuptools use it
7+ from distutils .dir_util import remove_tree # noqa: W4901 # setuptools use it
8+ from distutils .command .clean import clean as _clean # noqa: W4901 # setuptools use it
109
11- from setuptools import setup
10+ from setuptools import setup , Extension
11+ from setuptools .command .build_ext import build_ext as _build_ext
1212
1313import torch
1414
1515
16+ class CMakeExtension (Extension ):
17+
18+ def __init__ (self , name ):
19+ # don't invoke the original build_ext for this special extension
20+ super ().__init__ (name , sources = [])
21+
22+
1623class CMakeBuild ():
1724
18- def __init__ (self , build_type = "Debug" ):
25+ def __init__ (self , debug = False , dry_run = False ):
1926 self .current_dir = os .path .abspath (os .path .dirname (__file__ ))
2027 self .build_temp = self .current_dir + "/build/temp"
2128 self .extdir = self .current_dir + "/triton_kernels_benchmark"
22- self .build_type = build_type
29+ self .build_type = self . get_build_type ( debug )
2330 self .cmake_prefix_paths = [torch .utils .cmake_prefix_path ]
2431 self .use_ipex = False
32+ self .dry_run = dry_run
33+
34+ def get_build_type (self , debug ):
35+ DEBUG_OPTION = os .getenv ("DEBUG" , "0" )
36+ return "Debug" if debug or (DEBUG_OPTION == "1" ) else "Release"
2537
2638 def run (self ):
2739 self .check_ipex ()
@@ -41,7 +53,8 @@ def check_ipex(self):
4153
4254 def check_call (self , * popenargs , ** kwargs ):
4355 print (" " .join (popenargs [0 ]))
44- subprocess .check_call (* popenargs , ** kwargs )
56+ if not self .dry_run :
57+ subprocess .check_call (* popenargs , ** kwargs )
4558
4659 def build_extension (self ):
4760 ninja_dir = shutil .which ("ninja" )
@@ -94,38 +107,27 @@ def clean(self):
94107 os .path .dirname (__file__ )))
95108
96109
97- class build ( _build ):
110+ class build_ext ( _build_ext ):
98111
99112 def run (self ):
100- self .build_cmake ()
101- super ().run ()
102-
103- def build_cmake (self ):
104- DEBUG_OPTION = os .getenv ("DEBUG" , "0" )
105- debug = DEBUG_OPTION == "1"
106- if hasattr (self , "debug" ):
107- debug = debug or self .debug
108- build_type = "Debug" if debug else "Release"
109- cmake = CMakeBuild (build_type )
113+ cmake = CMakeBuild (debug = self .debug , dry_run = self .dry_run )
110114 cmake .run ()
115+ super ().run ()
111116
112117
113118class clean (_clean ):
114119
115120 def run (self ):
116- self .clean_cmake ()
117- super ().run ()
118-
119- def clean_cmake (self ):
120- cmake = CMakeBuild ()
121+ cmake = CMakeBuild (dry_run = self .dry_run )
121122 cmake .clean ()
123+ super ().run ()
122124
123125
124126setup (name = "triton-kernels-benchmark" , packages = [
125127 "triton_kernels_benchmark" ,
126128], package_dir = {
127129 "triton_kernels_benchmark" : "triton_kernels_benchmark" ,
128130}, package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]}, cmdclass = {
129- "build " : build ,
131+ "build_ext " : build_ext ,
130132 "clean" : clean ,
131- })
133+ }, ext_modules = [ CMakeExtension ( "triton_kernels_benchmark" )] )
0 commit comments