99 print ('You need to install pytorch first.' )
1010 sys .exit (1 )
1111
12- from subprocess import check_call
12+ from subprocess import check_call , check_output
1313from setuptools import setup , Extension , find_packages , distutils
1414from setuptools .command .build_ext import build_ext
1515from distutils .spawn import find_executable
16+ from distutils .version import LooseVersion
1617from sysconfig import get_paths
1718
1819import distutils .ccompiler
3233base_dir = os .path .dirname (os .path .abspath (__file__ ))
3334python_include_dir = get_paths ()['include' ]
3435
36+ # from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/__init__.py
37+ def which (thefile ):
38+ path = os .environ .get ("PATH" , os .defpath ).split (os .pathsep )
39+ for d in path :
40+ fname = os .path .join (d , thefile )
41+ fnames = [fname ]
42+ if sys .platform == 'win32' :
43+ exts = os .environ .get ('PATHEXT' , '' ).split (os .pathsep )
44+ fnames += [fname + ext for ext in exts ]
45+ for name in fnames :
46+ if os .access (name , os .F_OK | os .X_OK ) and not os .path .isdir (name ):
47+ return name
48+ return None
49+
50+ def get_cmake_command ():
51+ def _get_version (cmd ):
52+ for line in check_output ([cmd , '--version' ]).decode ('utf-8' ).split ('\n ' ):
53+ if 'version' in line :
54+ return LooseVersion (line .strip ().split (' ' )[2 ])
55+ raise RuntimeError ('no version found' )
56+ "Returns cmake command."
57+ cmake_command = 'cmake'
58+ if platform .system () == 'Windows' :
59+ return cmake_command
60+ cmake3 = which ('cmake3' )
61+ cmake = which ('cmake' )
62+ if cmake3 is not None and _get_version (cmake3 ) >= LooseVersion ("3.13.0" ):
63+ cmake_command = 'cmake3'
64+ return cmake_command
65+ elif cmake is not None and _get_version (cmake ) >= LooseVersion ("3.13.0" ):
66+ return cmake_command
67+ else :
68+ raise RuntimeError ('no cmake or cmake3 with version >= 3.13.0 found' )
3569
3670def _check_env_flag (name , default = '' ):
3771 return os .getenv (name , default ).upper () in ['ON' , '1' , 'YES' , 'TRUE' , 'Y' ]
@@ -155,7 +189,8 @@ def run(self):
155189 # Generate the code before globbing!
156190 generate_ipex_cpu_aten_code (base_dir )
157191
158- cmake = find_executable ('cmake3' ) or find_executable ('cmake' )
192+ cmake = get_cmake_command ()
193+
159194 if cmake is None :
160195 raise RuntimeError (
161196 "CMake must be installed to build the following extensions: " +
@@ -175,7 +210,7 @@ def build_extension(self, ext):
175210
176211 build_type = 'Release'
177212 use_ninja = False
178-
213+
179214 if _check_env_flag ('DEBUG' ):
180215 build_type = 'Debug'
181216
@@ -192,7 +227,7 @@ def build_extension(self, ext):
192227 '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=' + ext_dir ,
193228 '-DPYTHON_INCLUDE_DIR=' + python_include_dir ,
194229 '-DPYTORCH_INCLUDE_DIRS=' + pytorch_install_dir + "/include" ,
195- '-DPYTORCH_LIBRARY_DIRS=' + pytorch_install_dir + "/lib" ,
230+ '-DPYTORCH_LIBRARY_DIRS=' + pytorch_install_dir + "/lib" ,
196231 ]
197232
198233 if _check_env_flag ("IPEX_DISP_OP" ):
0 commit comments