Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 136 additions & 25 deletions C/CUDA/CUDA_Driver/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ else
missing
end

# an explicit "fork" preference wins over the JULIA_CUDA_FORK_VERSION_CHECK environment
# variable; when neither is set the value stays `missing`.
fork_preference =
    haskey(preferences, "fork") ? parse_preference(preferences["fork"]) :
    haskey(ENV, "JULIA_CUDA_FORK_VERSION_CHECK") ?
        parse_preference(ENV["JULIA_CUDA_FORK_VERSION_CHECK"]) :
    missing


# libraries handed over as dependencies when the forwards-compatible driver is inspected
libcuda_deps = [libcuda_debugger, libnvidia_nvvm, libnvidia_ptxjitcompiler]
# name under which the system driver is looked up
libcuda_system = if Sys.iswindows()
    "nvcuda"
else
    "libcuda.so.1"
end

Expand Down Expand Up @@ -55,6 +64,76 @@ if Libdl.dlopen(libcuda_system, Libdl.RTLD_NOLOAD; throw_error=false) !== nothin
return
end

# helper function to load a driver in the current process, query its version, and
# optionally query device capabilities.
#
# `driver` is the name/path of the driver library; `deps` are libraries opened before it.
# returns `(driver_version, device_capabilities)`, or `(nothing, nothing)` when a library
# cannot be opened, a required entry point is missing, or a driver call reports a non-zero
# status. `device_capabilities` is a `Vector{VersionNumber}` that stays empty unless
# `inspect_devices` is set. every handle opened here is closed again before returning.
function inspect_driver_in_memory(driver, deps=String[]; inspect_devices=false)
    # attribute identifiers passed to cuDeviceGetAttribute
    DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75
    DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76

    failure = (nothing, nothing)
    handles = Ptr{Cvoid}[]
    try
        for dep in deps
            dep_handle = Libdl.dlopen(dep; throw_error=false)
            dep_handle === nothing && return failure
            push!(handles, dep_handle)
        end

        driver_handle = Libdl.dlopen(driver; throw_error=false)
        driver_handle === nothing && return failure
        push!(handles, driver_handle)

        # a missing entry point is reported like any other failure instead of throwing
        lookup(name) = Libdl.dlsym(driver_handle, name; throw_error=false)

        cuInit = lookup("cuInit")
        cuInit === nothing && return failure
        ccall(cuInit, Cint, (UInt32,), 0) == 0 || return failure

        cuDriverGetVersion = lookup("cuDriverGetVersion")
        cuDriverGetVersion === nothing && return failure
        raw_version = Ref{Cint}()
        ccall(cuDriverGetVersion, Cint, (Ptr{Cint},), raw_version) == 0 || return failure
        # the raw value is decoded as major*1000 + minor*10 + patch
        ver_major, remainder = divrem(raw_version[], 1000)
        ver_minor, ver_patch = divrem(remainder, 10)
        driver_version = VersionNumber(ver_major, ver_minor, ver_patch)

        capabilities = VersionNumber[]
        if inspect_devices
            cuDeviceGetCount = lookup("cuDeviceGetCount")
            cuDeviceGet = lookup("cuDeviceGet")
            cuDeviceGetAttribute = lookup("cuDeviceGetAttribute")
            if cuDeviceGetCount === nothing || cuDeviceGet === nothing ||
               cuDeviceGetAttribute === nothing
                return failure
            end

            device_count = Ref{Cint}()
            ccall(cuDeviceGetCount, Cint, (Ptr{Cint},), device_count) == 0 ||
                return failure

            # device ordinals are zero-based
            for ordinal in 0:(device_count[] - 1)
                device = Ref{Cint}()
                ccall(cuDeviceGet, Cint, (Ptr{Cint}, Cint), device, ordinal) == 0 ||
                    return failure

                cap_major = Ref{Cint}()
                ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), cap_major,
                      DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device[]) == 0 ||
                    return failure
                cap_minor = Ref{Cint}()
                ccall(cuDeviceGetAttribute, Cint, (Ptr{Cint}, UInt32, Cint), cap_minor,
                      DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device[]) == 0 ||
                    return failure

                push!(capabilities, VersionNumber(cap_major[], cap_minor[]))
            end
        end

        return driver_version, capabilities
    finally
        # close in reverse order so the driver is released before its dependencies
        for handle in Iterators.reverse(handles)
            # TODO(@apozharski) we should check if they were already loaded (perhaps with RTLD_NOLOAD)
            Libdl.dlclose(handle)
        end
    end
end

# helper function to load a driver, query its version, and optionally query device
# capabilities. needs to happen in a separate process because dlclose is unreliable.
function inspect_driver(driver, deps=String[]; inspect_devices=false)
Expand Down Expand Up @@ -148,6 +227,7 @@ function inspect_driver(driver, deps=String[]; inspect_devices=false)
end

# fetch driver details
if fork_preference === missing || fork_preference
compat_driver_task = @static if VERSION >= v"1.12-"
# XXX: avoid concurrent compilation (JuliaLang/julia#59834)
Threads.@spawn :samepool inspect_driver(libcuda_compat, libcuda_deps)
Expand All @@ -160,37 +240,68 @@ end
else
Threads.@spawn inspect_driver(libcuda_system; inspect_devices=true)
end
compat_driver_details = fetch(compat_driver_task)
if compat_driver_details === nothing
@debug "Failed to load forwards-compatible driver."
return
end
compat_driver_version = compat_driver_details::VersionNumber
@debug "Forwards compatible driver version: $compat_driver_version"
system_driver_details = fetch(system_driver_task)
if system_driver_details === nothing
@debug "Failed to load system driver."
return
end
system_driver_version = system_driver_details[1]::VersionNumber
device_capabilities = system_driver_details[2]::Vector{VersionNumber}
@debug "System driver version: $system_driver_version"

# determine if loading the forwards-compatible driver would exclude devices
for (dev, cap) in enumerate(device_capabilities)
# CUDA 12 deprecated Kepler
if compat_driver_version >= v"12" && system_driver_version < v"12" && v"3.0" <= cap <= v"3.5"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
compat_driver_details = fetch(compat_driver_task)
if compat_driver_details === nothing
@debug "Failed to load forwards-compatible driver."
return
end
compat_driver_version = compat_driver_details::VersionNumber
@debug "Forwards compatible driver version: $compat_driver_version"
system_driver_details = fetch(system_driver_task)
if system_driver_details === nothing
@debug "Failed to load system driver."
return
end
system_driver_version = system_driver_details[1]::VersionNumber
device_capabilities = system_driver_details[2]::Vector{VersionNumber}
@debug "System driver version: $system_driver_version"

# CUDA 13 deprecated Maxwell, Pascal, and Volta
if compat_driver_version >= v"13" && system_driver_version < v"13" && v"5.0" <= cap <= v"7.2"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
# determine if loading the forwards-compatible driver would exclude devices
for (dev, cap) in enumerate(device_capabilities)
# CUDA 12 deprecated Kepler
if compat_driver_version >= v"12" && system_driver_version < v"12" && v"3.0" <= cap <= v"3.5"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
return
end

# CUDA 13 deprecated Maxwell, Pascal, and Volta
if compat_driver_version >= v"13" && system_driver_version < v"13" && v"5.0" <= cap <= v"7.2"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
return
end
end
else
(compat_driver_details,compat_device_capabilities) = inspect_driver_in_memory(libcuda_compat, libcuda_deps)
if compat_driver_details === nothing
@debug "Failed to load forwards-compatible driver."
return
end
end
compat_driver_version = compat_driver_details::VersionNumber
@debug "Forwards compatible driver version: $compat_driver_version"
(system_driver_details,device_details) = inspect_driver_in_memory(libcuda_system; inspect_devices=true)
if system_driver_details === nothing
@debug "Failed to load system driver."
return
end
system_driver_version = system_driver_details::VersionNumber
device_capabilities = device_details
@debug "System driver version: $system_driver_version"

# determine if loading the forwards-compatible driver would exclude devices
for (dev, cap) in enumerate(device_capabilities)
# CUDA 12 deprecated Kepler
if compat_driver_version >= v"12" && system_driver_version < v"12" && v"3.0" <= cap <= v"3.5"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
return
end

# CUDA 13 deprecated Maxwell, Pascal, and Volta
if compat_driver_version >= v"13" && system_driver_version < v"13" && v"5.0" <= cap <= v"7.2"
@debug "Loading forwards-compatible driver would exclude device $dev with capability $cap"
return
end
end
end
# finally, load the forwards-compatible driver
@debug "Using forwards-compatible CUDA driver."
global libcuda = libcuda_compat
Expand Down