@@ -303,8 +303,8 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]:
303303 return nvidia_lib_paths + lib_paths
304304
305305
306- def _preload_cuda_deps (lib_folder : str , lib_name : str , required : bool = True ) -> None : # type: ignore[valid-type]
307- """Preloads cuda deps if they could not be found otherwise."""
306+ def _preload_cuda_lib (lib_folder : str , lib_name : str , required : bool = True ) -> None : # type: ignore[valid-type]
307+ """Preloads cuda library if it could not be found otherwise."""
308308 # Should only be called on Linux if default path resolution has failed
309309 assert platform .system () == "Linux" , "Should only be called on Linux"
310310
@@ -320,6 +320,39 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) ->
320320 ctypes .CDLL (lib_path )
321321
322322
323+ def _preload_cuda_deps (err : _Optional [OSError ] = None ) -> None :
324+ cuda_libs : dict [str , str ] = {
325+ "cublas" : "libcublas.so.*[0-9]" ,
326+ "cudnn" : "libcudnn.so.*[0-9]" ,
327+ "cuda_nvrtc" : "libnvrtc.so.*[0-9]" ,
328+ "cuda_runtime" : "libcudart.so.*[0-9]" ,
329+ "cuda_cupti" : "libcupti.so.*[0-9]" ,
330+ "cufft" : "libcufft.so.*[0-9]" ,
331+ "curand" : "libcurand.so.*[0-9]" ,
332+ "nvjitlink" : "libnvJitLink.so.*[0-9]" ,
333+ "cusparse" : "libcusparse.so.*[0-9]" ,
334+ "cusparselt" : "libcusparseLt.so.*[0-9]" ,
335+ "cusolver" : "libcusolver.so.*[0-9]" ,
336+ "nccl" : "libnccl.so.*[0-9]" ,
337+ "nvshmem" : "libnvshmem_host.so.*[0-9]" ,
338+ "cufile" : "libcufile.so.*[0-9]" ,
339+ }
340+
341+ # If error is passed, re-raise it if it's not about one of the abovementioned
342+ # libraries
343+ if err is not None and [
344+ lib for lib in cuda_libs .values () if lib .split ("." , 1 )[0 ] in err .args [0 ]
345+ ]:
346+ raise err
347+
348+ # Otherwise, try to preload dependencies from site-packages
349+ for lib_folder , lib_name in cuda_libs .items ():
350+ _preload_cuda_lib (lib_folder , lib_name )
351+
352+ # libnvToolsExt is Optional Dependency
353+ _preload_cuda_lib ("nvtx" , "libnvToolsExt.so.*[0-9]" , required = False )
354+
355+
323356# See Note [Global dependencies]
324357def _load_global_deps () -> None :
325358 if platform .system () == "Windows" :
@@ -346,43 +379,15 @@ def _load_global_deps() -> None:
346379 # libtorch_global_deps.so always depends on cudart, check if it's installed and loaded
347380 if "libcudart.so" not in _maps :
348381 return
349- # If all above-mentioned conditions are met, preload nvrtc and nvjitlink
350- _preload_cuda_deps ("cuda_nvrtc" , "libnvrtc.so.*[0-9]" )
351- _preload_cuda_deps ("cuda_nvrtc" , "libnvrtc-builtins.so.*[0-9]" )
352- _preload_cuda_deps ("nvjitlink" , "libnvJitLink.so.*[0-9]" )
382+ # If all above-mentioned conditions are met, preload CUDA dependencies
383+ _preload_cuda_deps ()
353384 except Exception :
354385 pass
355386
356387 except OSError as err :
357- # Can only happen for wheel with cuda libs as PYPI deps
388+ # Can happen for wheel with cuda libs as PYPI deps
358389 # As PyTorch is not purelib, but nvidia-*-cu12 is
359- cuda_libs : dict [str , str ] = {
360- "cublas" : "libcublas.so.*[0-9]" ,
361- "cudnn" : "libcudnn.so.*[0-9]" ,
362- "cuda_nvrtc" : "libnvrtc.so.*[0-9]" ,
363- "cuda_runtime" : "libcudart.so.*[0-9]" ,
364- "cuda_cupti" : "libcupti.so.*[0-9]" ,
365- "cufft" : "libcufft.so.*[0-9]" ,
366- "curand" : "libcurand.so.*[0-9]" ,
367- "nvjitlink" : "libnvJitLink.so.*[0-9]" ,
368- "cusparse" : "libcusparse.so.*[0-9]" ,
369- "cusparselt" : "libcusparseLt.so.*[0-9]" ,
370- "cusolver" : "libcusolver.so.*[0-9]" ,
371- "nccl" : "libnccl.so.*[0-9]" ,
372- "nvshmem" : "libnvshmem_host.so.*[0-9]" ,
373- "cufile" : "libcufile.so.*[0-9]" ,
374- }
375-
376- is_cuda_lib_err = [
377- lib for lib in cuda_libs .values () if lib .split ("." )[0 ] in err .args [0 ]
378- ]
379- if not is_cuda_lib_err :
380- raise err
381- for lib_folder , lib_name in cuda_libs .items ():
382- _preload_cuda_deps (lib_folder , lib_name )
383-
384- # libnvToolsExt is Optional Dependency
385- _preload_cuda_deps ("nvtx" , "libnvToolsExt.so.*[0-9]" , required = False )
390+ _preload_cuda_deps (err )
386391 ctypes .CDLL (global_deps_lib_path , mode = ctypes .RTLD_GLOBAL )
387392
388393
0 commit comments