Skip to content

Commit e4db3d9

Browse files
committed
fix assertion logging
1 parent 8430e8e commit e4db3d9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

cuda_setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,18 @@ def locate_cuda():
119119
'include': os.path.join(home, 'include'),
120120
'lib64': os.path.join(home, 'lib64')}
121121
cuda_ver = os.path.basename(os.path.realpath(home)).split("-")[1].split(".")
122-
cuda_ver = 10 * int(cuda_ver[0]) + int(cuda_ver[1])
123-
assert cuda_ver >= 70, f"too low cuda ver {cuda_ver}"
124-
logging.info("cuda_ver: %s", cuda_ver)
122+
major, minor = int(cuda_ver[0]), int(cuda_ver[1])
123+
cuda_ver = 10 * major + minor
124+
assert cuda_ver >= 70, f"too low cuda ver {major}.{minor}"
125+
print(f"cuda_ver: {major}.{minor}")
125126
arch = get_cuda_arch(cuda_ver)
126127
sm_list = get_cuda_sm_list(cuda_ver)
127128
compute = get_cuda_compute(cuda_ver)
128129
post_args = [f"-arch=sm_{arch}"] + \
129130
[f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_list] + \
130131
[f"-gencode=arch=compute_{compute},code=compute_{compute}",
131132
"-ptxas-options=-v", "-O2"]
132-
logging.info("nvcc post args: %s", post_args)
133+
print(f"nvcc post args: {post_args}")
133134
if HALF_PRECISION:
134135
post_args = [flag for flag in post_args if "52" not in flag]
135136

0 commit comments

Comments
 (0)