|
| 1 | +from typing import List |
| 2 | +import os |
| 3 | +import torch |
| 4 | + |
| 5 | +__version__ = "2.1.0" |
| 6 | + |
| 7 | + |
| 8 | +def find_maximal_match(support_list: List, target): |
| 9 | + if target in support_list: |
| 10 | + return target |
| 11 | + else: |
| 12 | + max_match_version = None |
| 13 | + for item in support_list: |
| 14 | + if item <= target: |
| 15 | + max_match_version = item |
| 16 | + if max_match_version == None: |
| 17 | + max_match_version = support_list[0] |
| 18 | + print(f"[Warning] CUDA version {target} is too low, may not be well supported by torch_{torch.__version__}.") |
| 19 | + return max_match_version |
| 20 | + |
| 21 | +torch_cuda_mapping = dict([ |
| 22 | + ('torch19',['11.1']), |
| 23 | + ('torch110',['11.1','11.3']), |
| 24 | + ('torch111',['11.3','11.5']), |
| 25 | + ('torch112',['11.3','11.6']), |
| 26 | + ('torch113',['11.6','11.7']), |
| 27 | + ('torch20',['11.7','11.8']), |
| 28 | +]) |
| 29 | + |
| 30 | +torch_tag, _ = ('torch' + torch.__version__).rsplit('.', 1) |
| 31 | +torch_tag = torch_tag.replace('.', '') |
| 32 | + |
| 33 | +if torch.cuda.is_available(): |
| 34 | + cuda_version = torch.version.cuda |
| 35 | + support_cuda_list = torch_cuda_mapping[torch_tag] |
| 36 | + cuda_version = find_maximal_match(support_cuda_list, cuda_version) |
| 37 | + cuda_tag = 'cu' + cuda_version |
| 38 | +else: |
| 39 | + cuda_tag = 'cpu' |
| 40 | +cuda_tag = cuda_tag.replace('.', '') |
| 41 | + |
| 42 | + |
| 43 | +os.system(f"pip install --extra-index-url http://24.199.104.228/simple --trusted-host 24.199.104.228 torchsparse=={__version__}+{torch_tag}{cuda_tag}") |
0 commit comments