Skip to content

Commit fd9e25e

Browse files
Support KA unified memory (#630)
1 parent 1e2cb1c commit fd9e25e

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ ExprTools = "0.1"
4040
GPUArrays = "11.2.1"
4141
GPUCompiler = "0.26, 0.27, 1"
4242
GPUToolbox = "0.1, 0.2, 0.3"
43-
KernelAbstractions = "0.9.1"
43+
KernelAbstractions = "0.9.38"
4444
LLVM = "7.2, 8, 9"
4545
LLVMDowngrader_jll = "0.6"
4646
LinearAlgebra = "1"

src/MetalKernels.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MetalKernels
22

33
using ..Metal
4-
using ..Metal: @device_override
4+
using ..Metal: @device_override, DefaultStorageMode, SharedStorage
55

66
import KernelAbstractions as KA
77

@@ -22,9 +22,9 @@ The `KernelAbstractions` backend for running on Metal GPUs.
2222
struct MetalBackend <: KA.GPU
2323
end
2424

25-
KA.allocate(::MetalBackend, ::Type{T}, dims::Tuple) where T = MtlArray{T}(undef, dims)
26-
KA.zeros(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.zeros(T, dims)
27-
KA.ones(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.ones(T, dims)
25+
KA.allocate(::MetalBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where T = MtlArray{T, length(dims), unified ? SharedStorage : DefaultStorageMode}(undef, dims)
26+
KA.zeros(::MetalBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where T = Metal.zeros(T, dims; storage=unified ? SharedStorage : DefaultStorageMode)
27+
KA.ones(::MetalBackend, ::Type{T}, dims::Tuple; unified::Bool = false) where T = Metal.ones(T, dims; storage=unified ? SharedStorage : DefaultStorageMode)
2828

2929
KA.get_backend(::MtlArray) = MetalBackend()
3030
KA.synchronize(::MetalBackend) = synchronize()
@@ -33,6 +33,7 @@ KA.functional(::MetalBackend) = Metal.functional()
3333

3434
KA.supports_float64(::MetalBackend) = false
3535
KA.supports_atomics(::MetalBackend) = false
36+
KA.supports_unified(::MetalBackend) = true
3637

3738
Adapt.adapt_storage(::MetalBackend, a::Array) = Adapt.adapt(MtlArray, a)
3839
Adapt.adapt_storage(::MetalBackend, a::MtlArray) = a

0 commit comments

Comments
 (0)