Skip to content

Commit 61a7174

Browse files
committed
check cuda ver
1 parent 186b751 commit 61a7174

File tree

1 file changed

+69
-15
lines changed

1 file changed

+69
-15
lines changed

cuda_setup.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,60 @@ def find_in_path(name, path):
2828
return None
2929

3030

31+
def get_cuda_sm_list(cuda_ver):
    """Return the SM numbers to emit ``-gencode`` flags for.

    When the ``CUDA_SM_LIST`` environment variable is set, its
    comma-separated entries are used as given (no version filtering).
    Surrounding whitespace is stripped and empty entries are dropped, so
    values such as ``"60, 70,"`` do not produce malformed compiler flags.
    Otherwise a built-in list is trimmed according to ``cuda_ver``.

    Args:
        cuda_ver: CUDA toolkit version encoded as ``10 * major + minor``
            (for example ``111`` for CUDA 11.1).

    Returns:
        A list of SM numbers as strings, e.g. ``["52", "60", "61"]``.
    """
    if "CUDA_SM_LIST" in os.environ:
        entries = (sm.strip() for sm in os.environ["CUDA_SM_LIST"].split(","))
        return [sm for sm in entries if sm]

    sm_list = ["30", "52", "60", "61", "70", "75", "80", "86"]
    if cuda_ver >= 110:
        # The 11.x toolchains are not asked to build sm_30 ...
        filter_list = ["30"]
        if cuda_ver == 110:
            # ... and exactly 11.0 is additionally not asked for sm_86.
            filter_list += ["86"]
    else:
        # Older toolkits: drop every architecture newer than the toolkit.
        filter_list = ["80", "86"]
        if cuda_ver < 100:
            filter_list += ["75"]
        if cuda_ver < 90:
            filter_list += ["70"]
        if cuda_ver < 80:
            filter_list += ["60", "61"]
    return [sm for sm in sm_list if sm not in filter_list]
50+
51+
52+
def get_cuda_compute(cuda_ver):
    """Return the compute capability used for the PTX ``-gencode`` entry.

    The ``CUDA_COMPUTE`` environment variable, when set, is returned
    verbatim. Otherwise the value is chosen from the toolkit version.

    Args:
        cuda_ver: CUDA toolkit version encoded as ``10 * major + minor``
            (for example ``111`` for CUDA 11.1).

    Returns:
        The compute capability number as a string, e.g. ``"75"``.

    Raises:
        ValueError: If ``CUDA_COMPUTE`` is unset and ``cuda_ver`` is
            below 70, for which no default is defined.
    """
    if "CUDA_COMPUTE" in os.environ:
        return os.environ["CUDA_COMPUTE"]
    if cuda_ver < 70:
        raise ValueError(f"unsupported cuda ver {cuda_ver}")
    if cuda_ver < 80:
        return "52"
    if cuda_ver < 90:
        return "61"
    if cuda_ver < 100:
        return "70"
    if cuda_ver < 110:
        return "75"
    if cuda_ver == 110:
        return "80"
    # 11.1 and anything newer: previously only exactly 11.1 was matched,
    # which left the result unassigned for later toolkits.
    return "86"
69+
70+
71+
def get_cuda_arch(cuda_ver):
    """Return the SM number used for nvcc's ``-arch=sm_<N>`` flag.

    The ``CUDA_ARCH`` environment variable, when set, is returned
    verbatim. Otherwise the value is chosen from the toolkit version.

    Args:
        cuda_ver: CUDA toolkit version encoded as ``10 * major + minor``
            (for example ``111`` for CUDA 11.1).

    Returns:
        The architecture number as a string, e.g. ``"52"``.

    Raises:
        ValueError: If ``CUDA_ARCH`` is unset and ``cuda_ver`` is below
            70, for which no default is defined.
    """
    if "CUDA_ARCH" in os.environ:
        return os.environ["CUDA_ARCH"]
    if cuda_ver < 70:
        raise ValueError(f"unsupported cuda ver {cuda_ver}")
    if cuda_ver < 92:
        return "30"
    if cuda_ver < 110:
        return "50"
    if cuda_ver == 110:
        return "52"
    # 11.1 and newer. The version is an int, so it must not be compared
    # against the string "111" (that test could never succeed and left the
    # result unassigned).
    return "80"
84+
3185
def locate_cuda():
3286
"""Locate the CUDA environment on the system
3387
If a valid cuda installation is found
@@ -60,22 +114,22 @@ def locate_cuda():
60114
'your path, or set $CUDA_HOME to enable CUDA extensions')
61115
return None
62116
home = os.path.dirname(os.path.dirname(nvcc))
63-
64117
cudaconfig = {'home': home,
65-
'nvcc': nvcc,
66-
'include': os.path.join(home, 'include'),
67-
'lib64': os.path.join(home, 'lib64')}
68-
post_args = [
69-
"-arch=sm_52",
70-
"-gencode=arch=compute_52,code=sm_52",
71-
"-gencode=arch=compute_60,code=sm_60",
72-
"-gencode=arch=compute_61,code=sm_61",
73-
"-gencode=arch=compute_70,code=sm_70",
74-
"-gencode=arch=compute_75,code=sm_75",
75-
"-gencode=arch=compute_80,code=sm_80",
76-
"-gencode=arch=compute_86,code=sm_86",
77-
"-gencode=arch=compute_86,code=compute_86",
78-
'--ptxas-options=-v', '-O2']
118+
'nvcc': nvcc,
119+
'include': os.path.join(home, 'include'),
120+
'lib64': os.path.join(home, 'lib64')}
121+
cuda_ver = os.path.basename(os.path.realpath(home)).split("-")[1].split(".")
122+
cuda_ver = 10 * int(cuda_ver[0]) + int(cuda_ver[1])
123+
assert cuda_ver >= 700, f"too low cuda ver {cuda_ver}"
124+
logging.info("cuda_ver: %s", cuda_ver)
125+
arch = get_cuda_arch(cuda_ver)
126+
sm_list = get_cuda_sm_list(cuda_ver)
127+
compute = get_cuda_compute(cuda_ver)
128+
post_args = [f"-arch=sm_{arch}"] + \
129+
[f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_list] + \
130+
[f"-gencode=arch=compute_{compute},code=compute_{compute}",
131+
"-ptxas-options=-v", "-O2"]
132+
logging.info("nvcc post args: %s", post_args)
79133
if HALF_PRECISION:
80134
post_args = [flag for flag in post_args if "52" not in flag]
81135

0 commit comments

Comments
 (0)