We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3b73841 commit 9d0d98aCopy full SHA for 9d0d98a
torch/cuda/_utils.py
@@ -71,10 +71,22 @@ def _get_hiprtc_library() -> ctypes.CDLL:
71
72
73
def _get_nvrtc_library() -> ctypes.CDLL:
74
+ major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr]
75
if sys.platform == "win32":
- return ctypes.CDLL("nvrtc64_120_0.dll")
76
+ nvrtc_libs = [
77
+ f"nvrtc64_{major_version}0_0.dll",
78
+ ]
79
else:
- return ctypes.CDLL("libnvrtc.so")
80
81
+ f"libnvrtc.so.{major_version}",
82
+ "libnvrtc.so", # Fallback to unversioned
83
84
+ for lib_name in nvrtc_libs:
85
+ try:
86
+ return ctypes.CDLL(lib_name)
87
+ except OSError:
88
+ continue
89
+ raise OSError("Could not find any NVRTC library")
90
91
92
def _get_gpu_rtc_library() -> ctypes.CDLL:
0 commit comments