Skip to content

Commit c4c0e28

Browse files
Use CPU memcpy for SharedStorage copyto! (#445)
1 parent 46298ca commit c4c0e28

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/array.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::Arra
401401
end
402402
return dest
403403
end
404+
# CPU -> GPU, specialised for destinations with `Metal.SharedStorage`: copies `n` elements
# from `src` (starting at `soffs`) into `dest` (starting at `doffs`) and returns `dest`.
function Base.unsafe_copyto!(::MTLDevice,
                             dest::MtlArray{T,<:Any,Metal.SharedStorage}, doffs,
                             src::Array{T}, soffs, n) where {T}
    # The copy below is a plain memcpy rather than an API call, so it is not ordered
    # with other operations; synchronize before touching the memory.
    synchronize()
    GC.@preserve src dest begin
        # Array obtained from `unsafe_wrap` so that the destination can be addressed
        # through an ordinary pointer.
        dest_host = unsafe_wrap(Array, dest)
        unsafe_copyto!(pointer(dest_host, doffs), pointer(src, soffs), n)
    end
    return dest
end
404410

405411
# GPU -> CPU
406412
function Base.unsafe_copyto!(dev::MTLDevice, dest::Array{T}, doffs, src::MtlArray{T}, soffs, n) where T
@@ -414,6 +420,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::Array{T}, doffs, src::MtlArra
414420
end
415421
return dest
416422
end
423+
# GPU -> CPU, specialised for sources with `Metal.SharedStorage`: copies `n` elements
# from `src` (starting at `soffs`) into `dest` (starting at `doffs`) and returns `dest`.
function Base.unsafe_copyto!(::MTLDevice,
                             dest::Array{T}, doffs,
                             src::MtlArray{T,<:Any,Metal.SharedStorage}, soffs, n) where {T}
    # The copy below is a plain memcpy rather than an API call, so it is not ordered
    # with other operations; synchronize before touching the memory.
    synchronize()
    GC.@preserve src dest begin
        dest_ptr = pointer(dest, doffs)
        # Array obtained from `unsafe_wrap` so that the source can be addressed
        # through an ordinary pointer.
        src_host = unsafe_wrap(Array, src)
        unsafe_copyto!(dest_ptr, pointer(src_host, soffs), n)
    end
    return dest
end
417429

418430
# GPU -> GPU
419431
function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlArray{T}, soffs, n) where T
@@ -427,6 +439,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlA
427439
end
428440
return dest
429441
end
442+
# GPU -> GPU, specialised for the case where both arrays use `Metal.SharedStorage`:
# copies `n` elements from `src` (starting at `soffs`) into `dest` (starting at `doffs`)
# and returns `dest`.
function Base.unsafe_copyto!(::MTLDevice,
                             dest::MtlArray{T,<:Any,Metal.SharedStorage}, doffs,
                             src::MtlArray{T,<:Any,Metal.SharedStorage}, soffs, n) where {T}
    # The copy below is a plain memcpy rather than an API call, so it is not ordered
    # with other operations; synchronize before touching the memory.
    synchronize()
    GC.@preserve src dest begin
        # Arrays obtained from `unsafe_wrap` so that both sides can be addressed
        # through ordinary pointers.
        dest_host = unsafe_wrap(Array, dest)
        src_host = unsafe_wrap(Array, src)
        unsafe_copyto!(pointer(dest_host, doffs), pointer(src_host, soffs), n)
    end
    return dest
end
430448

431449

432450
## regular gpu array adaptor

test/array.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,30 @@ end
6969
@test collect(Metal.fill(1, 2, 2)) == ones(Float32, 2, 2)
7070
end
7171

72+
# Exercise `copyto!` in every direction (host -> device, device -> host, device -> device)
# for each storage mode and element type listed below.
@testset "copyto!" begin
    storage_modes = [Metal.PrivateStorage, Metal.SharedStorage]
    element_types = [Float16, Float32, Bool, Int16, Int32, Int64, Int8,
                     UInt16, UInt32, UInt64, UInt8]

    @testset "$T, $S" for S in storage_modes, T in element_types
        dims = (1000, 17, 10)
        host_src = rand(T, dims)
        dev_src = mtl(host_src; storage=S)

        # cpu -> gpu
        dev_dst = Metal.zeros(T, dims; storage=S)
        copyto!(dev_dst, host_src)
        @test Array(dev_dst) == Array(host_src)

        # gpu -> cpu
        host_dst = zeros(T, dims)
        copyto!(host_dst, dev_src)
        @test Array(host_dst) == Array(dev_src)

        # gpu -> gpu
        dev_dst = Metal.zeros(T, dims; storage=S)
        copyto!(dev_dst, dev_src)
        @test Array(dev_dst) == Array(dev_src)
    end
end
95+
7296
# Return `true` when `Metal.storagemode(arr)` compares equal to `smode`.
function check_storagemode(arr, smode)
    actual_mode = Metal.storagemode(arr)
    return actual_mode == smode
end
7397

7498
# There is some repetition to the GPUArrays tests to test for different storagemodes

0 commit comments

Comments
 (0)