@@ -119,17 +119,18 @@ def locate_cuda():
119
119
'include' : os .path .join (home , 'include' ),
120
120
'lib64' : os .path .join (home , 'lib64' )}
121
121
cuda_ver = os .path .basename (os .path .realpath (home )).split ("-" )[1 ].split ("." )
122
- cuda_ver = 10 * int (cuda_ver [0 ]) + int (cuda_ver [1 ])
123
- assert cuda_ver >= 70 , f"too low cuda ver { cuda_ver } "
124
- logging .info ("cuda_ver: %s" , cuda_ver )
122
+ major , minor = int (cuda_ver [0 ]), int (cuda_ver [1 ])
123
+ cuda_ver = 10 * major + minor
124
+ assert cuda_ver >= 70 , f"too low cuda ver { major } .{ minor } "
125
+ print (f"cuda_ver: { major } .{ minor } " )
125
126
arch = get_cuda_arch (cuda_ver )
126
127
sm_list = get_cuda_sm_list (cuda_ver )
127
128
compute = get_cuda_compute (cuda_ver )
128
129
post_args = [f"-arch=sm_{ arch } " ] + \
129
130
[f"-gencode=arch=compute_{ sm } ,code=sm_{ sm } " for sm in sm_list ] + \
130
131
[f"-gencode=arch=compute_{ compute } ,code=compute_{ compute } " ,
131
132
"-ptxas-options=-v" , "-O2" ]
132
- logging . info ( "nvcc post args: %s" , post_args )
133
+ print ( f "nvcc post args: { post_args } " )
133
134
if HALF_PRECISION :
134
135
post_args = [flag for flag in post_args if "52" not in flag ]
135
136
0 commit comments