Skip to content

Commit 9d0d98a

Browse files
atalmanpytorchmergebot
authored andcommitted
Use cuda nvrtc so file based on cuda version used by torch (pytorch#163642)
Fixes pytorch#162367 Pull Request resolved: pytorch#163642 Approved by: https://github.com/msaroufim
1 parent 3b73841 commit 9d0d98a

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

torch/cuda/_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,22 @@ def _get_hiprtc_library() -> ctypes.CDLL:
7171

7272

7373
def _get_nvrtc_library() -> ctypes.CDLL:
74+
major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr]
7475
if sys.platform == "win32":
75-
return ctypes.CDLL("nvrtc64_120_0.dll")
76+
nvrtc_libs = [
77+
f"nvrtc64_{major_version}0_0.dll",
78+
]
7679
else:
77-
return ctypes.CDLL("libnvrtc.so")
80+
nvrtc_libs = [
81+
f"libnvrtc.so.{major_version}",
82+
"libnvrtc.so", # Fallback to unversioned
83+
]
84+
for lib_name in nvrtc_libs:
85+
try:
86+
return ctypes.CDLL(lib_name)
87+
except OSError:
88+
continue
89+
raise OSError("Could not find any NVRTC library")
7890

7991

8092
def _get_gpu_rtc_library() -> ctypes.CDLL:

0 commit comments

Comments
 (0)