Skip to content

Commit 3fdc5db

Browse files
malfetpytorchmergebot
authored andcommitted
Make CUDA preload logic more straightforward (pytorch#167046)
I.e. remove distinction between two cases, and always preload full set of libraries For some reason, when one uses `virtualenv` instead of `venv`, preloading `cudart` works, but it fails to find cudnn or cublasLT later on Fix it, by getting read of partial preload logic for one of the cases and always preload full set of libraries Test plan on stock Ubuntu: ``` pip install virtualenv virtualenv --symlinks -p python3.11 --prompt virtv venv-virt source venv-virt/bin/activate pip install torch python -c 'import torch' ``` Fixes pytorch#165812 Pull Request resolved: pytorch#167046 Approved by: https://github.com/atalman
1 parent cc477f6 commit 3fdc5db

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

torch/__init__.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,8 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]:
303303
return nvidia_lib_paths + lib_paths
304304

305305

306-
def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type]
307-
"""Preloads cuda deps if they could not be found otherwise."""
306+
def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type]
307+
"""Preloads cuda library if it could not be found otherwise."""
308308
# Should only be called on Linux if default path resolution have failed
309309
assert platform.system() == "Linux", "Should only be called on Linux"
310310

@@ -320,6 +320,39 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) ->
320320
ctypes.CDLL(lib_path)
321321

322322

323+
def _preload_cuda_deps(err: _Optional[OSError] = None) -> None:
324+
cuda_libs: dict[str, str] = {
325+
"cublas": "libcublas.so.*[0-9]",
326+
"cudnn": "libcudnn.so.*[0-9]",
327+
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
328+
"cuda_runtime": "libcudart.so.*[0-9]",
329+
"cuda_cupti": "libcupti.so.*[0-9]",
330+
"cufft": "libcufft.so.*[0-9]",
331+
"curand": "libcurand.so.*[0-9]",
332+
"nvjitlink": "libnvJitLink.so.*[0-9]",
333+
"cusparse": "libcusparse.so.*[0-9]",
334+
"cusparselt": "libcusparseLt.so.*[0-9]",
335+
"cusolver": "libcusolver.so.*[0-9]",
336+
"nccl": "libnccl.so.*[0-9]",
337+
"nvshmem": "libnvshmem_host.so.*[0-9]",
338+
"cufile": "libcufile.so.*[0-9]",
339+
}
340+
341+
# If error is passed, re-raise it if it's not about one of the abovementioned
342+
# libraries
343+
if err is not None and [
344+
lib for lib in cuda_libs.values() if lib.split(".", 1)[0] in err.args[0]
345+
]:
346+
raise err
347+
348+
# Otherwise, try to preload dependencies from site-packages
349+
for lib_folder, lib_name in cuda_libs.items():
350+
_preload_cuda_lib(lib_folder, lib_name)
351+
352+
# libnvToolsExt is Optional Dependency
353+
_preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False)
354+
355+
323356
# See Note [Global dependencies]
324357
def _load_global_deps() -> None:
325358
if platform.system() == "Windows":
@@ -346,43 +379,15 @@ def _load_global_deps() -> None:
346379
# libtorch_global_deps.so always depends in cudart, check if its installed and loaded
347380
if "libcudart.so" not in _maps:
348381
return
349-
# If all above-mentioned conditions are met, preload nvrtc and nvjitlink
350-
_preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]")
351-
_preload_cuda_deps("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]")
352-
_preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]")
382+
# If all above-mentioned conditions are met, preload CUDA dependencies
383+
_preload_cuda_deps()
353384
except Exception:
354385
pass
355386

356387
except OSError as err:
357-
# Can only happen for wheel with cuda libs as PYPI deps
388+
# Can happen for wheel with cuda libs as PYPI deps
358389
# As PyTorch is not purelib, but nvidia-*-cu12 is
359-
cuda_libs: dict[str, str] = {
360-
"cublas": "libcublas.so.*[0-9]",
361-
"cudnn": "libcudnn.so.*[0-9]",
362-
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
363-
"cuda_runtime": "libcudart.so.*[0-9]",
364-
"cuda_cupti": "libcupti.so.*[0-9]",
365-
"cufft": "libcufft.so.*[0-9]",
366-
"curand": "libcurand.so.*[0-9]",
367-
"nvjitlink": "libnvJitLink.so.*[0-9]",
368-
"cusparse": "libcusparse.so.*[0-9]",
369-
"cusparselt": "libcusparseLt.so.*[0-9]",
370-
"cusolver": "libcusolver.so.*[0-9]",
371-
"nccl": "libnccl.so.*[0-9]",
372-
"nvshmem": "libnvshmem_host.so.*[0-9]",
373-
"cufile": "libcufile.so.*[0-9]",
374-
}
375-
376-
is_cuda_lib_err = [
377-
lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0]
378-
]
379-
if not is_cuda_lib_err:
380-
raise err
381-
for lib_folder, lib_name in cuda_libs.items():
382-
_preload_cuda_deps(lib_folder, lib_name)
383-
384-
# libnvToolsExt is Optional Dependency
385-
_preload_cuda_deps("nvtx", "libnvToolsExt.so.*[0-9]", required=False)
390+
_preload_cuda_deps(err)
386391
ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
387392

388393

0 commit comments

Comments
 (0)