diff --git a/Project.toml b/Project.toml index fc258555..13838457 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Adapt = "4" GPUArrays = "11.2.1" GPUCompiler = "1.6" -KernelAbstractions = "0.9.2" +KernelAbstractions = "0.9.38" LLVM = "9.1" LinearAlgebra = "1" OpenCL_jll = "=2024.10.24" diff --git a/lib/cl/state.jl b/lib/cl/state.jl index de44bd5d..bed42dc5 100644 --- a/lib/cl/state.jl +++ b/lib/cl/state.jl @@ -9,6 +9,7 @@ function clear_task_local_storage!() delete!(task_local_storage(), :CLPlatform) delete!(task_local_storage(), :CLQueue) delete!(task_local_storage(), :CLMemoryBackend) + delete!(task_local_storage(), :CLUnifiedMemoryBackend) end @@ -163,7 +164,7 @@ struct SVMBackend <: AbstractMemoryBackend end struct USMBackend <: AbstractMemoryBackend end struct BufferBackend <: AbstractMemoryBackend end -function supported_memory_backends(dev::Device) +function supported_memory_backends(dev::Device; unified=false) backends = AbstractMemoryBackend[] # unified shared memory is the first choice, as it gives us separate host and device @@ -177,7 +178,7 @@ function supported_memory_backends(dev::Device) # plain old buffers are always supported, but we only want to use them if we have the # buffer device address extension, which allows us to reference them by raw pointers. - if bda_supported(dev) + if !unified && bda_supported(dev) push!(backends, BufferBackend()) end @@ -187,7 +188,7 @@ function supported_memory_backends(dev::Device) push!(backends, SVMBackend()) end - if isempty(backends) + if !unified && isempty(backends) # as a last resort, use plain buffers without the ability to reference by pointer. # this severely limits compatibility, but it's better than nothing. push!(backends, BufferBackend()) @@ -196,8 +197,9 @@ function supported_memory_backends(dev::Device) return backends end -function default_memory_backend(dev::Device) - supported_backends = supported_memory_backends(dev) +function default_memory_backend(dev::Device; unified=false) + supported_backends = supported_memory_backends(dev; unified) + isempty(supported_backends) && return nothing backend_str = load_preference(OpenCL, "default_memory_backend") backend_str === nothing && return first(supported_backends) @@ -211,8 +213,7 @@ function default_memory_backend(dev::Device) else error("Unknown memory backend '$backend_str' requested") end - in(backend, supported_backends) ? backend : nothing - backend + return in(backend, supported_backends) ? backend : nothing end function memory_backend() @@ -230,6 +231,17 @@ function memory_backend() end end +function unified_memory_backend() + return get!(task_local_storage(), :CLUnifiedMemoryBackend) do + dev = device() + backend = default_memory_backend(dev; unified=true) + if backend === nothing + error("Device $(dev) does not support any of the available unified memory backends") + end + backend + end +end + ## per-task queues diff --git a/src/OpenCLKernels.jl b/src/OpenCLKernels.jl index b01565ce..e06102cf 100644 --- a/src/OpenCLKernels.jl +++ b/src/OpenCLKernels.jl @@ -17,9 +17,22 @@ export OpenCLBackend struct OpenCLBackend <: KA.GPU end -KA.allocate(::OpenCLBackend, ::Type{T}, dims::Tuple) where T = CLArray{T}(undef, dims) -KA.zeros(::OpenCLBackend, ::Type{T}, dims::Tuple) where T = OpenCL.zeros(T, dims) -KA.ones(::OpenCLBackend, ::Type{T}, dims::Tuple) where T = OpenCL.ones(T, dims) +function KA.allocate(::OpenCLBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where T + if unified + memory_backend = cl.unified_memory_backend() + if memory_backend === cl.USMBackend() + return CLArray{T, length(dims), cl.UnifiedSharedMemory}(undef, dims) + elseif memory_backend === cl.SVMBackend() + return CLArray{T, length(dims), cl.SharedVirtualMemory}(undef, dims) + else + throw(ArgumentError("Unified memory not supported")) + end + else + return CLArray{T}(undef, dims) + end +end + +KA.supports_unified(::OpenCLBackend) = cl.default_memory_backend(cl.device(); unified=true) !== nothing KA.get_backend(::CLArray) = OpenCLBackend() # TODO should be non-blocking