@@ -28,6 +28,60 @@ def find_in_path(name, path):
28
28
return None
29
29
30
30
31
+ def get_cuda_sm_list (cuda_ver ):
32
+ if "CUDA_SM_LIST" in os .environ :
33
+ sm_list = os .environ ["CUDA_SM_LIST" ].split ("," )
34
+ else :
35
+ sm_list = ["30" , "52" , "60" , "61" , "70" , "75" , "80" , "86" ]
36
+ if cuda_ver >= 110 :
37
+ filter_list = ["30" ]
38
+ if cuda_ver == 110 :
39
+ filter_list += ["86" ]
40
+ else :
41
+ filter_list = ["80" , "86" ]
42
+ if cuda_ver < 100 :
43
+ filter_list += ["75" ]
44
+ if cuda_ver < 90 :
45
+ filter_list += ["70" ]
46
+ if cuda_ver < 80 :
47
+ filter_list += ["60" , "61" ]
48
+ sm_list = [sm for sm in sm_list if sm not in filter_list ]
49
+ return sm_list
50
+
51
+
52
+ def get_cuda_compute (cuda_ver ):
53
+ if "CUDA_COMPUTE" in os .environ :
54
+ compute = os .environ ["CUDA_COMPUTE" ]
55
+ else :
56
+ if 70 <= cuda_ver < 80 :
57
+ compute = "52"
58
+ if 80 <= cuda_ver < 90 :
59
+ compute = "61"
60
+ if 90 <= cuda_ver < 100 :
61
+ compute = "70"
62
+ if 100 <= cuda_ver < 110 :
63
+ compute = "75"
64
+ if cuda_ver == 110 :
65
+ compute = "80"
66
+ if cuda_ver == 111 :
67
+ compute = "86"
68
+ return compute
69
+
70
+
71
+ def get_cuda_arch (cuda_ver ):
72
+ if "CUDA_ARCH" in os .environ :
73
+ arch = os .environ ["CUDA_ARCH" ]
74
+ else :
75
+ if 70 <= cuda_ver < 92 :
76
+ arch = "30"
77
+ if 92 <= cuda_ver < 110 :
78
+ arch = "50"
79
+ if cuda_ver == 110 :
80
+ arch = "52"
81
+ if cuda_ver == 111 :
82
+ arch = "80"
83
+ return arch
84
+
31
85
def locate_cuda ():
32
86
"""Locate the CUDA environment on the system
33
87
If a valid cuda installation is found
@@ -60,22 +114,23 @@ def locate_cuda():
60
114
'your path, or set $CUDA_HOME to enable CUDA extensions' )
61
115
return None
62
116
home = os .path .dirname (os .path .dirname (nvcc ))
63
-
64
117
cudaconfig = {'home' : home ,
65
- 'nvcc' : nvcc ,
66
- 'include' : os .path .join (home , 'include' ),
67
- 'lib64' : os .path .join (home , 'lib64' )}
68
- post_args = [
69
- "-arch=sm_52" ,
70
- "-gencode=arch=compute_52,code=sm_52" ,
71
- "-gencode=arch=compute_60,code=sm_60" ,
72
- "-gencode=arch=compute_61,code=sm_61" ,
73
- "-gencode=arch=compute_70,code=sm_70" ,
74
- "-gencode=arch=compute_75,code=sm_75" ,
75
- "-gencode=arch=compute_80,code=sm_80" ,
76
- "-gencode=arch=compute_86,code=sm_86" ,
77
- "-gencode=arch=compute_86,code=compute_86" ,
78
- '--ptxas-options=-v' , '-O2' ]
118
+ 'nvcc' : nvcc ,
119
+ 'include' : os .path .join (home , 'include' ),
120
+ 'lib64' : os .path .join (home , 'lib64' )}
121
+ cuda_ver = os .path .basename (os .path .realpath (home )).split ("-" )[1 ].split ("." )
122
+ major , minor = int (cuda_ver [0 ]), int (cuda_ver [1 ])
123
+ cuda_ver = 10 * major + minor
124
+ assert cuda_ver >= 70 , f"too low cuda ver { major } .{ minor } "
125
+ print (f"cuda_ver: { major } .{ minor } " )
126
+ arch = get_cuda_arch (cuda_ver )
127
+ sm_list = get_cuda_sm_list (cuda_ver )
128
+ compute = get_cuda_compute (cuda_ver )
129
+ post_args = [f"-arch=sm_{ arch } " ] + \
130
+ [f"-gencode=arch=compute_{ sm } ,code=sm_{ sm } " for sm in sm_list ] + \
131
+ [f"-gencode=arch=compute_{ compute } ,code=compute_{ compute } " ,
132
+ "--ptxas-options=-v" , "-O2" ]
133
+ print (f"nvcc post args: { post_args } " )
79
134
if HALF_PRECISION :
80
135
post_args = [flag for flag in post_args if "52" not in flag ]
81
136
0 commit comments