Skip to content

Commit 7a56f28

Browse files
kentang-mitys-2020Haotian Tang
authored
v2.1.0 installation (#210)
* [Major] Update README for v2.1.0 * [Minor] Update README.md * [Minor] Update README.md * [Minor] Update README.md * [Minor] Add installation. --------- Co-authored-by: ys-2020 <[email protected]> Co-authored-by: Haotian Tang <[email protected]>
1 parent 3ed2dfa commit 7a56f28

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

install.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import List
2+
import os
3+
import torch
4+
5+
__version__ = "2.1.0"
6+
7+
8+
def find_maximal_match(support_list: List, target):
9+
if target in support_list:
10+
return target
11+
else:
12+
max_match_version = None
13+
for item in support_list:
14+
if item <= target:
15+
max_match_version = item
16+
if max_match_version == None:
17+
max_match_version = support_list[0]
18+
print(f"[Warning] CUDA version {target} is too low, may not be well supported by torch_{torch.__version__}.")
19+
return max_match_version
20+
21+
torch_cuda_mapping = dict([
22+
('torch19',['11.1']),
23+
('torch110',['11.1','11.3']),
24+
('torch111',['11.3','11.5']),
25+
('torch112',['11.3','11.6']),
26+
('torch113',['11.6','11.7']),
27+
('torch20',['11.7','11.8']),
28+
])
29+
30+
torch_tag, _ = ('torch' + torch.__version__).rsplit('.', 1)
31+
torch_tag = torch_tag.replace('.', '')
32+
33+
if torch.cuda.is_available():
34+
cuda_version = torch.version.cuda
35+
support_cuda_list = torch_cuda_mapping[torch_tag]
36+
cuda_version = find_maximal_match(support_cuda_list, cuda_version)
37+
cuda_tag = 'cu' + cuda_version
38+
else:
39+
cuda_tag = 'cpu'
40+
cuda_tag = cuda_tag.replace('.', '')
41+
42+
43+
os.system(f"pip install --extra-index-url http://24.199.104.228/simple --trusted-host 24.199.104.228 torchsparse=={__version__}+{torch_tag}{cuda_tag}")

0 commit comments

Comments
 (0)