Skip to content

Commit 67d668c

Browse files
authored
Use GPU blit for large SharedStorage GPU→GPU copies (>32MB) (#716)
1 parent 1c1115e commit 67d668c

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

src/array.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,10 +440,19 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlA
440440
end
441441
return dest
442442
end
443-
function Base.unsafe_copyto!(::MTLDevice, dest::MtlArray{T,<:Any,Metal.SharedStorage}, doffs, src::MtlArray{T,<:Any,Metal.SharedStorage}, soffs, n) where T
444-
# these copies are implemented using pure memcpy's, not API calls, so aren't ordered.
443+
function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T, <:Any, Metal.SharedStorage}, doffs, src::MtlArray{T, <:Any, Metal.SharedStorage}, soffs, n) where {T}
445444
synchronize()
446-
GC.@preserve src dest unsafe_copyto!(pointer(unsafe_wrap(Array,dest), doffs), pointer(unsafe_wrap(Array,src), soffs), n)
445+
bytes = n * sizeof(T)
446+
# Use GPU blit for large copies (>32MiB) where it's faster than CPU memcpy.
447+
# For small copies, CPU memcpy avoids GPU command buffer overhead.
448+
if bytes >= 32 * 2^20 # If changed, also change in tests
449+
GC.@preserve src dest unsafe_copyto!(dev, pointer(dest, doffs), pointer(src, soffs), n)
450+
if Base.isbitsunion(T)
451+
error("Not implemented")
452+
end
453+
else
454+
GC.@preserve src dest unsafe_copyto!(pointer(unsafe_wrap(Array, dest), doffs), pointer(unsafe_wrap(Array, src), soffs), n)
455+
end
447456
return dest
448457
end
449458

test/array.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,33 @@ end
6969
end
7070

7171
@testset "copyto!" begin
72-
@testset "$T, $S" for S in [Metal.PrivateStorage, Metal.SharedStorage],
73-
T in [Float16, Float32, Bool, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8]
74-
dim = (1000,17,10)
75-
A = rand(T,dim)
76-
mtlA = mtl(A;storage=S)
77-
78-
#cpu -> gpu
79-
res = Metal.zeros(T,dim;storage=S)
80-
copyto!(res,A)
81-
@test Array(res) == Array(A)
82-
83-
#gpu -> cpu
84-
res = zeros(T,dim)
85-
copyto!(res,mtlA)
86-
@test Array(res) == Array(mtlA)
72+
@testset "$S" for S in [Metal.PrivateStorage, Metal.SharedStorage]
73+
@testset "$T" for T in [Float16, Float32, Bool, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8]
74+
dim = (1000, 17, 10)
75+
A = rand(T, dim)
76+
mtlA = mtl(A; storage = S)
77+
78+
#cpu -> gpu
79+
res = Metal.zeros(T, dim; storage = S)
80+
copyto!(res, A)
81+
@test Array(res) == Array(A)
82+
83+
#gpu -> cpu
84+
res = zeros(T, dim)
85+
copyto!(res, mtlA)
86+
@test Array(res) == Array(mtlA)
87+
88+
#gpu -> gpu
89+
res = Metal.zeros(T, dim; storage = S)
90+
copyto!(res, mtlA)
91+
@test Array(res) == Array(mtlA)
92+
end
8793

88-
#gpu -> gpu
89-
res = Metal.zeros(T,dim;storage=S)
90-
copyto!(res,mtlA)
94+
# Large array, only test Float32
95+
A = rand(Float32, 32 * 2^20)
96+
mtlA = mtl(A; storage = S)
97+
res = similar(A)
98+
copyto!(res, mtlA)
9199
@test Array(res) == Array(mtlA)
92100
end
93101
end

0 commit comments

Comments
 (0)