Skip to content

Commit 15b0ebf

Browse files
Support MtlArray backed by Array memory (#320)
Co-authored-by: Tim Besard <[email protected]>
1 parent f9776db commit 15b0ebf

File tree

4 files changed

+115
-13
lines changed

4 files changed

+115
-13
lines changed

lib/mtl/buffer.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,36 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
3434
end
3535

3636
function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr;
37-
storage=Managed, hazard_tracking=DefaultTracking,
37+
nocopy=false, storage=Shared, hazard_tracking=DefaultTracking,
3838
cache_mode=DefaultCPUCache)
39-
storage == Private && error("Can't create a Private copy-allocated buffer.")
39+
storage == Private && error("Cannot allocate-and-initialize a Private buffer")
4040
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode
4141

4242
@assert 0 < bytesize <= dev.maxBufferLength
43-
ptr = alloc_buffer(dev, bytesize, opts, ptr)
43+
ptr = if nocopy
44+
alloc_buffer_nocopy(dev, bytesize, opts, ptr)
45+
else
46+
alloc_buffer(dev, bytesize, opts, ptr)
47+
end
4448

4549
return MTLBuffer(ptr)
4650
end
4751

52+
const PAGESIZE = ccall(:getpagesize, Cint, ())
53+
function can_alloc_nocopy(ptr::Ptr, bytesize::Integer)
54+
# newBufferWithBytesNoCopy has several restrictions:
55+
## the pointer has to be page-aligned
56+
if Int64(ptr) % PAGESIZE != 0
57+
return false
58+
end
59+
## the new buffer needs to be page-aligned
60+
## XXX: on macOS 14, this doesn't seem required; is this a documentation issue?
61+
if bytesize % PAGESIZE != 0
62+
return false
63+
end
64+
return true
65+
end
66+
4867
# from device
4968
alloc_buffer(dev::MTLDevice, bytesize, opts) =
5069
@objc [dev::id{MTLDevice} newBufferWithLength:bytesize::NSUInteger
@@ -53,6 +72,14 @@ alloc_buffer(dev::MTLDevice, bytesize, opts, ptr::Ptr) =
5372
@objc [dev::id{MTLDevice} newBufferWithBytes:ptr::Ptr{Cvoid}
5473
length:bytesize::NSUInteger
5574
options:opts::MTLResourceOptions]::id{MTLBuffer}
75+
function alloc_buffer_nocopy(dev::MTLDevice, bytesize, opts, ptr::Ptr)
76+
can_alloc_nocopy(ptr, bytesize) ||
77+
throw(ArgumentError("Cannot allocate nocopy buffer from non-aligned memory"))
78+
@objc [dev::id{MTLDevice} newBufferWithBytesNoCopy:ptr::Ptr{Cvoid}
79+
length:bytesize::NSUInteger
80+
options:opts::MTLResourceOptions
81+
deallocator:nil::id{Object}]::id{MTLBuffer}
82+
end
5683

5784
# from heap
5885
alloc_buffer(dev::MTLHeap, bytesize, opts) =

src/array.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
9191
end
9292
end
9393

94+
# Create MtlArray from MTLBuffer
95+
function MtlArray{T,N}(buf::B, dims::Dims{N}; kwargs...) where {B<:MTLBuffer,T,N}
96+
data = DataRef(buf) do buf
97+
free(buf)
98+
end
99+
return MtlArray{T,N}(data, dims; kwargs...)
100+
end
101+
94102
unsafe_free!(a::MtlArray) = GPUArrays.unsafe_free!(a.data)
95103

96104
device(A::MtlArray) = A.data[].device
@@ -491,6 +499,14 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false)
491499
return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
492500
end
493501

502+
function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, arr::Array, dims=size(arr);
503+
dev=current_device(), kwargs...) where {T,N}
504+
GC.@preserve arr begin
505+
buf = MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...)
506+
return A(buf, Dims(dims))
507+
end
508+
end
509+
494510
## resizing
495511

496512
"""

src/memory.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T}, src::Ptr{T}, N::Int
3535
storage_type = dst.buffer.storageMode
3636
if storage_type == MTL.MTLStorageModePrivate
3737
# stage through a shared buffer
38-
# shared = alloc(dev, N*sizeof(T), src; storage=Shared)
39-
# unsafe_copyto!(dev, dst, pointer(shared), N; queue, async=false)
40-
# free(shared)
41-
tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared) #CPU -> GPU (Shared)
42-
unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
38+
nocopy = MTL.can_alloc_nocopy(src, N*sizeof(T))
39+
tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared, nocopy)
40+
41+
# copy to the private buffer
42+
unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N;
43+
queue, async=(nocopy && async))
4344
free(tmp_buf)
4445
elseif storage_type == MTL.MTLStorageModeShared
4546
unsafe_copyto!(convert(Ptr{T}, dst), src, N)
@@ -54,12 +55,22 @@ end
5455
function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPtr{T}, N::Integer;
5556
queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
5657
storage_type = src.buffer.storageMode
57-
if storage_type == MTL.MTLStorageModePrivate
58+
if storage_type == MTL.MTLStorageModePrivate
5859
# stage through a shared buffer
59-
shared = alloc(dev, N*sizeof(T); storage=Shared)
60-
unsafe_copyto!(dev, MtlPtr{T}(shared, 0), MtlPtr{T}(src.buffer, src.offset), N; queue, async=false)
61-
unsafe_copyto!(dst, convert(Ptr{T}, shared), N)
62-
free(shared)
60+
nocopy = MTL.can_alloc_nocopy(dst, N*sizeof(T))
61+
tmp_buf = if nocopy
62+
alloc(dev, N*sizeof(T), dst; storage=Shared, nocopy)
63+
else
64+
alloc(dev, N*sizeof(T); storage=Shared)
65+
end
66+
unsafe_copyto!(dev, MtlPtr{T}(tmp_buf, 0), MtlPtr{T}(src.buffer, src.offset), N;
67+
queue, async=(nocopy && async))
68+
69+
# copy from the shared buffer
70+
if !nocopy
71+
unsafe_copyto!(dst, convert(Ptr{T}, tmp_buf), N)
72+
end
73+
free(tmp_buf)
6374
elseif storage_type == MTL.MTLStorageModeShared
6475
unsafe_copyto!(dst, convert(Ptr{T}, src), N)
6576
elseif storage_type == MTL.MTLStorageModeManaged

test/array.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,52 @@ end
305305
@test length(b) == 1
306306
end
307307

308+
function _alignedvec(::Type{T}, n::Integer, alignment::Integer=16384) where {T}
309+
ispow2(alignment) || throw(ArgumentError("$alignment is not a power of 2"))
310+
alignment sizeof(Int) || throw(ArgumentError("$alignment is not a multiple of $(sizeof(Int))"))
311+
isbitstype(T) || throw(ArgumentError("$T is not a bitstype"))
312+
p = Ref{Ptr{T}}()
313+
err = ccall(:posix_memalign, Cint, (Ref{Ptr{T}}, Csize_t, Csize_t), p, alignment, n*sizeof(T))
314+
iszero(err) || throw(OutOfMemoryError())
315+
return unsafe_wrap(Array, p[], n, own=true)
316+
end
317+
318+
@testset "unsafe_wrap" begin
319+
# Create page-aligned vector for testing
320+
arr1 = _alignedvec(Float32, 16384*2);
321+
fill!(arr1, zero(eltype(arr1)))
322+
marr1 = unsafe_wrap(MtlVector{Float32}, arr1);
323+
324+
@test all(arr1 .== 0)
325+
@test all(marr1 .== 0)
326+
327+
# XXX: Test fails when ordered as shown
328+
# @test all(arr1 .== 1)
329+
# @test all(marr1 .== 1)
330+
marr1 .+= 1;
331+
@test all(marr1 .== 1)
332+
@test all(arr1 .== 1)
333+
334+
arr1 .+= 1;
335+
@test all(marr1 .== 2)
336+
@test all(arr1 .== 2)
337+
338+
marr2 = Metal.zeros(Float32, 18000; storage=Shared);
339+
arr2 = unsafe_wrap(Vector{Float32}, marr2);
340+
341+
@test all(arr2 .== 0)
342+
@test all(marr2 .== 0)
343+
344+
# XXX: Test fails when ordered as shown
345+
# @test all(arr2 .== 1)
346+
# @test all(marr2 .== 1)
347+
marr2 .+= 1;
348+
@test all(marr2 .== 1)
349+
@test all(arr2 .== 1)
350+
351+
arr2 .+= 1;
352+
@test all(arr2 .== 2)
353+
@test all(marr2 .== 2)
354+
end
355+
308356
end

0 commit comments

Comments
 (0)