11module MetalKernels
22
33using .. Metal
4- using .. Metal: @device_override
4+ using .. Metal: @device_override , DefaultStorageMode, SharedStorage
55
66import KernelAbstractions as KA
77
@@ -22,9 +22,9 @@ The `KernelAbstractions` backend for running on Metal GPUs.
2222struct MetalBackend <: KA.GPU
2323end
2424
25- KA. allocate (:: MetalBackend , :: Type{T} , dims:: Tuple ) where T = MtlArray {T} (undef, dims)
26- KA. zeros (:: MetalBackend , :: Type{T} , dims:: Tuple ) where T = Metal. zeros (T, dims)
27- KA. ones (:: MetalBackend , :: Type{T} , dims:: Tuple ) where T = Metal. ones (T, dims)
25+ KA. allocate (:: MetalBackend , :: Type{T} , dims:: Tuple ; unified :: Bool = false ) where T = MtlArray {T, length(dims), unified ? SharedStorage : DefaultStorageMode } (undef, dims)
26+ KA. zeros (:: MetalBackend , :: Type{T} , dims:: Tuple ; unified :: Bool = false ) where T = Metal. zeros (T, dims; storage = unified ? SharedStorage : DefaultStorageMode )
27+ KA. ones (:: MetalBackend , :: Type{T} , dims:: Tuple ; unified :: Bool = false ) where T = Metal. ones (T, dims; storage = unified ? SharedStorage : DefaultStorageMode )
2828
2929KA. get_backend (:: MtlArray ) = MetalBackend ()
3030KA. synchronize (:: MetalBackend ) = synchronize ()
@@ -33,6 +33,7 @@ KA.functional(::MetalBackend) = Metal.functional()
3333
3434KA. supports_float64 (:: MetalBackend ) = false
3535KA. supports_atomics (:: MetalBackend ) = false
36+ KA. supports_unified (:: MetalBackend ) = true
3637
3738Adapt. adapt_storage (:: MetalBackend , a:: Array ) = Adapt. adapt (MtlArray, a)
3839Adapt. adapt_storage (:: MetalBackend , a:: MtlArray ) = a
0 commit comments