11#!/usr/bin/env python
22from __future__ import print_function
33
4+ TORCH_VERSION = '1.8.0'
5+ TORCH_IPEX_VERSION = '1.3.0'
6+
7+ # import torch
8+ import platform
9+ import pkg_resources
10+ import re
11+ from socket import timeout
12+ import subprocess
13+ import sys
14+ import os
15+ import urllib .request
16+
17+ try :
18+ from packaging import version
19+ except Exception :
20+ subprocess .check_call ([sys .executable , '-m' , 'pip' , 'install' , 'packaging' ])
21+ from packaging import version
22+
23+ installed_raw = {pkg for pkg in pkg_resources .working_set }
24+ installed = {}
25+ for i in installed_raw :
26+ installed [i .key ] = i .version
27+
28+ requires = {}
29+ requires_raw = {}
30+ try :
31+ with open ('requirements.txt' , 'r' ) as reader :
32+ for line in reader .readlines ():
33+ line_raw = line .replace ('\n ' , '' )
34+ line = line_raw .replace ('=' , '' )
35+ tmp = re .split ('[=<>]' , line )
36+ if len (tmp ) == 2 :
37+ requires [tmp [0 ]] = tmp [1 ]
38+ else :
39+ requires [tmp [0 ]] = ''
40+ requires_raw [tmp [0 ]] = line_raw
41+ except Exception :
42+ pass
43+
44+ restart = False
45+ for k in requires .keys ():
46+ if k in installed .keys ():
47+ if requires [k ] != '' and version .parse (installed [k ]) < version .parse (requires [k ]):
48+ subprocess .check_call ([sys .executable , '-m' , 'pip' , 'install' , requires_raw [k ]])
49+ if k == 'wheel' :
50+ restart = True
51+ else :
52+ subprocess .check_call ([sys .executable , '-m' , 'pip' , 'install' , k ])
53+ if k == 'wheel' :
54+ restart = True
55+ if restart :
56+ os .execv (sys .executable , ['python' ] + sys .argv )
57+ exit (1 )
58+
59+ TORCH_VERSION = os .getenv ('TORCH_VERSION' , TORCH_VERSION )
60+
461try :
562 import torch
663except ImportError as e :
7- print ('Unable to import torch. Error:' )
8- print ('\t ' , e )
9- print ('You need to install pytorch first.' )
10- sys .exit (1 )
64+ subprocess .check_call ([sys .executable , '-m' , 'pip' , 'install' , 'torch==' + TORCH_VERSION + '+cpu' , '-f' , 'https://download.pytorch.org/whl/torch_stable.html' ])
65+ import torch
66+
67+ PYTHON_VERSION = sys .version_info
68+ IS_WINDOWS = (platform .system () == 'Windows' )
69+ IS_DARWIN = (platform .system () == 'Darwin' )
70+ IS_LINUX = (platform .system () == 'Linux' )
71+
72+ TORCH_URL = 'torch @ https://download.pytorch.org/whl/cpu/torch-{0}%2Bcpu-cp{1}{2}-cp{1}{2}-linux_x86_64.whl' .format (TORCH_VERSION , PYTHON_VERSION .major , PYTHON_VERSION .minor )
73+ if IS_DARWIN :
74+ TORCH_URL = 'torch=={}' .format (TORCH_VERSION )
75+ else :
76+ OS_VER = 'linux_x86_64'
77+ if IS_WINDOWS :
78+ TORCH_URL = 'torch @ https://download.pytorch.org/whl/cpu/torch-{0}%2Bcpu-cp{1}{2}-cp{1}{2}-win_amd64.whl' .format (TORCH_VERSION , PYTHON_VERSION .major , PYTHON_VERSION .minor )
79+ OS_VER = 'win_amd64'
80+
81+ try :
82+ fp = urllib .request .urlopen ('https://download.pytorch.org/whl/torch_stable.html' , timeout = 30 )
83+ cont_bytes = fp .read ()
84+ cont = cont_bytes .decode ('utf8' ).replace ('\n ' , '' )
85+ fp .close ()
86+ lines = re .split (r'<br>' , cont )
87+
88+ for line in lines :
89+ matches = re .match ('<a href="(cpu\/torch-{0}.*cp{1}{2}.*{3}.*)">(.*)<\/a>' .format (TORCH_VERSION , PYTHON_VERSION .major , PYTHON_VERSION .minor , OS_VER ), line )
90+ if matches and len (matches .groups ()) == 2 :
91+ TORCH_URL = 'torch @ https://download.pytorch.org/whl/{}' .format (matches .group (2 ))
92+ break
93+ except Exception :
94+ pass
1195
1296from subprocess import check_call , check_output
1397from setuptools import setup , Extension , find_packages , distutils
22106import inspect
23107import multiprocessing
24108import multiprocessing .pool
25- import os
26- import platform
27- import re
28109import shutil
29- import subprocess
30- import sys
31110
32111pytorch_install_dir = os .path .dirname (os .path .abspath (torch .__file__ ))
33112base_dir = os .path .dirname (os .path .abspath (__file__ ))
@@ -86,20 +165,23 @@ def _get_env_backend():
86165
87166
88167def get_git_head_sha (base_dir ):
89- ipex_git_sha = subprocess .check_output (['git' , 'rev-parse' , 'HEAD' ],
90- cwd = base_dir ).decode ('ascii' ).strip ()
91- if os .path .isdir (os .path .join (base_dir , '..' , '.git' )):
92- torch_git_sha = subprocess .check_output (['git' , 'rev-parse' , 'HEAD' ],
93- cwd = os .path .join (
94- base_dir ,
95- '..' )).decode ('ascii' ).strip ()
96- else :
97- torch_git_sha = ''
168+ ipex_git_sha = ''
169+ torch_git_sha = ''
170+ try :
171+ ipex_git_sha = subprocess .check_output (['git' , 'rev-parse' , 'HEAD' ],
172+ cwd = base_dir ).decode ('ascii' ).strip ()
173+ if os .path .isdir (os .path .join (base_dir , '..' , '.git' )):
174+ torch_git_sha = subprocess .check_output (['git' , 'rev-parse' , 'HEAD' ],
175+ cwd = os .path .join (
176+ base_dir ,
177+ '..' )).decode ('ascii' ).strip ()
178+ except Exception :
179+ pass
98180 return ipex_git_sha , torch_git_sha
99181
100182
101183def get_build_version (ipex_git_sha ):
102- version = os .getenv ('TORCH_IPEX_VERSION' , '1.2.0' )
184+ version = os .getenv ('TORCH_IPEX_VERSION' , TORCH_IPEX_VERSION )
103185 if _check_env_flag ('VERSIONED_IPEX_BUILD' , default = '0' ):
104186 try :
105187 version += '+' + ipex_git_sha [:7 ]
@@ -271,11 +353,6 @@ def build_extension(self, ext):
271353
272354# Constant known variables used throughout this file
273355
274- # PyTorch installed library
275- IS_WINDOWS = (platform .system () == 'Windows' )
276- IS_DARWIN = (platform .system () == 'Darwin' )
277- IS_LINUX = (platform .system () == 'Linux' )
278-
279356
280357def make_relative_rpath (path ):
281358 if IS_DARWIN :
@@ -285,12 +362,17 @@ def make_relative_rpath(path):
285362 else :
286363 return '-Wl,-rpath,$ORIGIN/' + path
287364
365+ install_requires = [
366+ TORCH_URL ,
367+ ]
368+
288369setup (
289370 name = 'torch_ipex' ,
290371 version = version ,
291372 description = 'Intel PyTorch Extension' ,
292373 url = 'https://github.com/intel/intel-extension-for-pytorch' ,
293374 author = 'Intel/PyTorch Dev Team' ,
375+ install_requires = install_requires ,
294376 # Exclude the build files.
295377 #packages=find_packages(exclude=['build']),
296378 packages = [
0 commit comments