Skip to content

Commit 15829b0

Browse files
authored
Better API for Unified Memory (#305)
1 parent 91d72d0 commit 15829b0

File tree

7 files changed

+114
-55
lines changed

7 files changed

+114
-55
lines changed

docs/src/usage/array.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ can construct `MtlArray`s in the same way as regular `Array` objects:
2828

2929
```jldoctest
3030
julia> MtlArray{Int}(undef, 2)
31-
2-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
31+
2-element MtlVector{Int64, Private}:
3232
0
3333
0
3434
3535
julia> MtlArray{Int}(undef, (1,2))
36-
1×2 MtlMatrix{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
36+
1×2 MtlMatrix{Int64, Private}:
3737
0 0
3838
3939
julia> similar(ans)
40-
1×2 MtlMatrix{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
40+
1×2 MtlMatrix{Int64, Private}:
4141
0 0
4242
```
4343

@@ -46,7 +46,7 @@ Copying memory to or from the GPU can be expressed using constructors as well, o
4646

4747
```jldoctest
4848
julia> a = MtlArray([1,2])
49-
2-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
49+
2-element MtlVector{Int64, Private}:
5050
1
5151
2
5252
@@ -73,11 +73,11 @@ perform simple element-wise operations you can use `map` or `broadcast`:
7373
julia> a = MtlArray{Float32}(undef, (1,2));
7474
7575
julia> a .= 5
76-
1×2 MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
76+
1×2 MtlMatrix{Float32, Private}:
7777
5.0 5.0
7878
7979
julia> map(sin, a)
80-
1×2 MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
80+
1×2 MtlMatrix{Float32, Private}:
8181
-0.958924 -0.958924
8282
```
8383

@@ -86,23 +86,23 @@ To reduce the dimensionality of arrays, Metal.jl implements the various flavours
8686

8787
```jldoctest
8888
julia> a = Metal.ones(2,3)
89-
2×3 MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
89+
2×3 MtlMatrix{Float32, Private}:
9090
1.0 1.0 1.0
9191
1.0 1.0 1.0
9292
9393
julia> reduce(+, a)
9494
6.0f0
9595
9696
julia> mapreduce(sin, *, a; dims=2)
97-
2×1 MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
97+
2×1 MtlMatrix{Float32, Private}:
9898
0.59582335
9999
0.59582335
100100
101101
julia> b = Metal.zeros(1)
102-
1-element MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
102+
1-element MtlVector{Float32, Private}:
103103
0.0
104104
105105
julia> Base.mapreducedim!(identity, +, b, a)
106-
1×1 MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}:
106+
1×1 MtlMatrix{Float32, Private}:
107107
6.0
108108
```

lib/mtl/buffer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
Base.sizeof(buf::MTLBuffer) = Int(buf.length)
1616

1717
function Base.convert(::Type{Ptr{T}}, buf::MTLBuffer) where {T}
18-
buf.storageMode == Private && error("Cannot access the contents of a private buffer")
18+
buf.storageMode == MTLStorageModePrivate && error("Cannot access the contents of a private buffer")
1919
convert(Ptr{T}, buf.contents)
2020
end
2121

@@ -25,7 +25,7 @@ end
2525
function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
2626
storage=Private, hazard_tracking=DefaultTracking,
2727
cache_mode=DefaultCPUCache)
28-
opts = storage | hazard_tracking | cache_mode
28+
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode
2929

3030
@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
3131
ptr = alloc_buffer(dev, bytesize, opts)
@@ -37,7 +37,7 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer, ptr::Ptr;
3737
storage=Managed, hazard_tracking=DefaultTracking,
3838
cache_mode=DefaultCPUCache)
3939
storage == Private && error("Can't create a Private copy-allocated buffer.")
40-
opts = storage | hazard_tracking | cache_mode
40+
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode
4141

4242
@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
4343
ptr = alloc_buffer(dev, bytesize, opts, ptr)

lib/mtl/storage_type.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,30 @@
1-
abstract type StorageMode end
2-
export Shared, Managed, Private
1+
export Shared, Managed, Private, CPUStorage
32
export ReadUsage, WriteUsage, ReadWriteUsage
43

54
# Metal Has 4 storage types
65
# Shared -> Buffer in Host memory, accessed by the GPU. Requires no sync
76
# Managed -> Mirrored memory buffers in host and GPU. Requires syncing
87
# Private -> Memory in Device, not accessible by Host.
98
# Memoryless -> iOS stuff. ignore it
10-
module AS
11-
import ..StorageMode
129

10+
abstract type StorageMode end
1311
struct Shared <: StorageMode end
1412
struct Managed <: StorageMode end
1513
struct Private <: StorageMode end
1614
struct Memoryless <: StorageMode end
17-
end
18-
19-
const CPUStorage = Union{AS.Shared,AS.Managed}
20-
Base.convert(::Type{MTLStorageMode}, ::Type{AS.Shared}) = MTLStorageModeShared
21-
Base.convert(::Type{MTLStorageMode}, ::Type{AS.Managed}) = MTLStorageModeManaged
22-
Base.convert(::Type{MTLStorageMode}, ::Type{AS.Private}) = MTLStorageModePrivate
23-
Base.convert(::Type{MTLStorageMode}, ::Type{AS.Memoryless}) = MTLStorageModeMemoryless
2415

25-
Base.convert(::Type{MTLResourceOptions}, ::Type{AS.Shared}) = MTLResourceStorageModeShared
26-
Base.convert(::Type{MTLResourceOptions}, ::Type{AS.Managed}) = MTLResourceStorageModeManaged
27-
Base.convert(::Type{MTLResourceOptions}, ::Type{AS.Private}) = MTLResourceStorageModePrivate
28-
Base.convert(::Type{MTLResourceOptions}, ::Type{AS.Memoryless}) = MTLResourceStorageModeMemoryless
16+
const CPUStorage = Union{Shared,Managed}
17+
Base.convert(::Type{MTLStorageMode}, ::Type{Shared}) = MTLStorageModeShared
18+
Base.convert(::Type{MTLStorageMode}, ::Type{Managed}) = MTLStorageModeManaged
19+
Base.convert(::Type{MTLStorageMode}, ::Type{Private}) = MTLStorageModePrivate
20+
Base.convert(::Type{MTLStorageMode}, ::Type{Memoryless}) = MTLStorageModeMemoryless
2921

30-
Base.convert(::Type{MTLResourceOptions}, SM::MTL.MTLStorageMode) = MTLResourceOptions(UInt(SM) << 4)
22+
Base.convert(::Type{MTLResourceOptions}, ::Type{Shared}) = MTLResourceStorageModeShared
23+
Base.convert(::Type{MTLResourceOptions}, ::Type{Managed}) = MTLResourceStorageModeManaged
24+
Base.convert(::Type{MTLResourceOptions}, ::Type{Private}) = MTLResourceStorageModePrivate
25+
Base.convert(::Type{MTLResourceOptions}, ::Type{Memoryless}) = MTLResourceStorageModeMemoryless
3126

32-
const Shared = MTLResourceStorageModeShared
33-
const Managed = MTLResourceStorageModeManaged
34-
const Private = MTLResourceStorageModePrivate
35-
const Memoryless = MTLResourceStorageModeMemoryless
27+
Base.convert(::Type{MTLResourceOptions}, SM::MTLStorageMode) = MTLResourceOptions(UInt(SM) << 4)
3628

3729
const DefaultCPUCache = MTLResourceCPUCacheModeDefaultCache
3830
const CombinedWriteCPUCache = MTLResourceCPUCacheModeWriteCombined

src/array.jl

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# host array
22

3-
export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl
3+
export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private
44

55
function hasfieldcount(@nospecialize(dt))
66
try
@@ -77,8 +77,16 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
7777
function MtlArray{T,N}(data::DataRef{<:MTLBuffer}, dims::Dims{N};
7878
maxsize::Int=prod(dims) * sizeof(T), offset::Int=0) where {T,N}
7979
check_eltype(T)
80-
S = convert(MTL.MTLResourceOptions, data[].storageMode)
81-
obj = new{T,N,S}(copy(data), maxsize, offset, dims)
80+
storagemode = data[].storageMode
81+
if storagemode == MTL.MTLStorageModeShared
82+
obj = new{T,N,Shared}(copy(data), maxsize, offset, dims)
83+
elseif storagemode == MTL.MTLStorageModeManaged
84+
obj = new{T,N,Managed}(copy(data), maxsize, offset, dims)
85+
elseif storagemode == MTL.MTLStorageModePrivate
86+
obj = new{T,N,Private}(copy(data), maxsize, offset, dims)
87+
elseif storagemode == MTL.MTLStorageModeMemoryless
88+
obj = new{T,N,Memoryless}(copy(data), maxsize, offset, dims)
89+
end
8290
finalizer(unsafe_free!, obj)
8391
end
8492
end
@@ -90,6 +98,10 @@ device(A::MtlArray) = A.data[].device
9098
storagemode(x::MtlArray) = storagemode(typeof(x))
9199
storagemode(::Type{<:MtlArray{<:Any,<:Any,S}}) where {S} = S
92100

101+
is_shared(a::MtlArray) = storagemode(a) == Shared
102+
is_managed(a::MtlArray) = storagemode(a) == Managed
103+
is_private(a::MtlArray) = storagemode(a) == Private
104+
is_memoryless(a::MtlArray) = storagemode(a) == Memoryless
93105

94106
## convenience constructors
95107

@@ -144,15 +156,42 @@ Base.elsize(::Type{<:MtlArray{T}}) where {T} = sizeof(T)
144156
Base.size(x::MtlArray) = x.dims
145157
Base.sizeof(x::MtlArray) = Base.elsize(x) * length(x)
146158

147-
Base.pointer(x::MtlArray{T}) where {T} = Base.unsafe_convert(MtlPointer{T}, x)
148-
@inline function Base.pointer(x::MtlArray{T}, i::Integer) where T
149-
Base.unsafe_convert(MtlPointer{T}, x) + Base._memory_offset(x, i)
159+
@inline function Base.pointer(x::MtlArray{T}, i::Integer=1; storage=Private) where {T}
160+
PT = if storage == Private
161+
MtlPointer{T}
162+
elseif storage == Shared || storage == Managed
163+
Ptr{T}
164+
else
165+
error("unknown memory type")
166+
end
167+
Base.unsafe_convert(PT, x) + Base._memory_offset(x, i)
150168
end
151169

152-
Base.unsafe_convert(::Type{Ptr{S}}, x::MtlArray{T}) where {S, T} =
153-
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
154-
Base.unsafe_convert(::Type{MtlPointer{T}}, x::MtlArray) where {T} =
155-
MtlPointer{T}(x.data[], x.offset*Base.elsize(x))
170+
171+
function Base.unsafe_convert(::Type{MtlPointer{T}}, x::MtlArray) where {T}
172+
buf = x.data[]
173+
MtlPointer{T}(buf, x.offset*Base.elsize(x))
174+
end
175+
176+
function Base.unsafe_convert(::Type{Ptr{S}}, x::MtlArray{T}) where {S, T}
177+
buf = x.data[]
178+
if is_private(x)
179+
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
180+
end
181+
convert(Ptr{T}, buf) + x.offset*Base.elsize(x)
182+
end
183+
184+
185+
## indexing
186+
function Base.getindex(x::MtlArray{T,N,S}, I::Int) where {T,N,S<:Union{Shared,Managed}}
187+
@boundscheck checkbounds(x, I)
188+
unsafe_load(pointer(x, I; storage=S))
189+
end
190+
191+
function Base.setindex!(x::MtlArray{T,N,S}, v, I::Int) where {T,N,S<:Union{Shared,Managed}}
192+
@boundscheck checkbounds(x, I)
193+
unsafe_store!(pointer(x, I; storage=S), v)
194+
end
156195

157196

158197
## interop with other arrays
@@ -354,7 +393,7 @@ Uses Adapt.jl to act inside some wrapper structs.
354393
355394
```jldoctests
356395
julia> mtl(ones(3)')
357-
1×3 adjoint(::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}) with eltype Float32:
396+
1×3 adjoint(::MtlVector{Float32, Private}) with eltype Float32:
358397
1.0 1.0 1.0
359398
360399
julia> mtl(zeros(1,3); storage=Shared)
@@ -365,13 +404,13 @@ julia> mtl(1:3)
365404
1:3
366405
367406
julia> MtlArray(1:3)
368-
3-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
407+
3-element MtlVector{Int64, Private}:
369408
1
370409
2
371410
3
372411
373412
julia> mtl[1,2,3]
374-
3-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate}:
413+
3-element MtlVector{Int64, Private}:
375414
1
376415
2
377416
3
@@ -433,8 +472,9 @@ Base.unsafe_convert(::Type{MTL.MTLBuffer}, A::PermutedDimsArray) =
433472

434473
## unsafe_wrap
435474

436-
Base.unsafe_wrap(t::Type{<:Array}, arr::MtlArray, dims; own=false) =
437-
unsafe_wrap(t, arr.data[], dims; own=own)
475+
function Base.unsafe_wrap(::Type{<:Array}, arr::MtlArray{T,N}, dims=size(arr); own=false) where {T,N}
476+
return unsafe_wrap(Array{T,N}, arr.data[], dims; own=own)
477+
end
438478

439479
function Base.unsafe_wrap(t::Type{<:Array{T}}, buf::MTLBuffer, dims; own=false) where T
440480
ptr = convert(Ptr{T}, buf)

src/memory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N:
4545
unsafe_copyto!(convert(Ptr{T}, dst), src, N)
4646
elseif storage_type == MTL.MTLStorageModeManaged
4747
unsafe_copyto!(convert(Ptr{T}, dst), src, N)
48-
MTL.DidModifyRange!(dst, 1:N)
48+
MTL.DidModifyRange!(dst.buffer, 1:N)
4949
end
5050
return dst
5151
end

src/random.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A)
77
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A)
88

99
# GPUArrays out-of-place
10-
rand(T::Type, dims::Dims; storage::MTL.MTLResourceOptions=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...))
11-
randn(T::Type, dims::Dims; storage::MTL.MTLResourceOptions=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...)
10+
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...))
11+
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...)
1212

1313
# support all dimension specifications
14-
rand(T::Type, dim1::Integer, dims::Integer...; storage::MTL.MTLResourceOptions=DefaultStorageMode) =
14+
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
1515
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...))
16-
randn(T::Type, dim1::Integer, dims::Integer...; storage::MTL.MTLResourceOptions=DefaultStorageMode, kwargs...) =
16+
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) =
1717
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
1818

1919
# untyped out-of-place
20-
rand(dim1::Integer, dims::Integer...; storage::MTL.MTLResourceOptions=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...))
21-
randn(dim1::Integer, dims::Integer...; storage::MTL.MTLResourceOptions=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
20+
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...))
21+
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
2222

2323
# seeding
2424
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed)

test/array.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,35 @@ check_storagemode(arr, smode) = Metal.storagemode(arr) == smode
129129
# private storage errors.
130130
if SM == Metal.Private
131131
let arr_mtl = Metal.zeros(Float32, dim...; storage=Private)
132+
@test is_private(arr_mtl) && !is_shared(arr_mtl) && !is_managed(arr_mtl)
132133
@test_throws "Cannot access the contents of a private buffer" arr_cpu = unsafe_wrap(Array{Float32}, arr_mtl, dim)
133134
end
135+
136+
let b = rand(Float32, 10)
137+
arr_mtl = mtl(b, storage=Private)
138+
@test_throws ErrorException arr_mtl[1]
139+
@test Metal.@allowscalar arr_mtl[1] == b[1]
140+
end
141+
elseif SM == Metal.Shared
142+
let arr_mtl = Metal.zeros(Float32, dim...; storage=Shared)
143+
@test !is_private(arr_mtl) && is_shared(arr_mtl) && !is_managed(arr_mtl)
144+
@test unsafe_wrap(Array{Float32}, arr_mtl) isa Array{Float32}
145+
end
146+
147+
let b = rand(Float32, 10)
148+
arr_mtl = mtl(b, storage=Shared)
149+
@test arr_mtl[1] == b[1]
150+
end
151+
elseif SM == Metal.Managed
152+
let arr_mtl = Metal.zeros(Float32, dim...; storage=Managed)
153+
@test !is_private(arr_mtl) && !is_shared(arr_mtl) && is_managed(arr_mtl)
154+
@test unsafe_wrap(Array{Float32}, arr_mtl) isa Array{Float32}
155+
end
156+
157+
let b = rand(Float32, 10)
158+
arr_mtl = mtl(b, storage=Managed)
159+
@test arr_mtl[1] == b[1]
160+
end
134161
end
135162
end
136163

0 commit comments

Comments
 (0)