Skip to content

Commit a7721c3

Browse files
committed
copyto! for heterogeneous arrays.
1 parent bd10409 commit a7721c3

File tree

2 files changed

+46
-30
lines changed

2 files changed

+46
-30
lines changed

src/abstractarray.jl

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ for (W, ctor) in (:AT => (A,mut)->mut(A), Adapt.wrappers...)
7070
end
7171
end
7272

73-
# memory operations
7473

75-
## basic copy methods that dispatch to copyto! for linear copies
74+
# memory copying
75+
76+
## basic linear copies of identically-typed memory
7677

7778
"""
    materialize(x::AbstractArray)

Return `x` converted to a plain `Array` by calling `Array(x)`.
"""
function materialize(x::AbstractArray)
    return Array(x)
end
7879
"""
    materialize(x::GPUArray)

Return `x` itself, unchanged; no conversion is performed for a `GPUArray`.
"""
function materialize(x::GPUArray)
    return x
end
@@ -106,8 +107,10 @@ for (D, S) in ((GPUArray, AbstractArray), (Array, GPUArray), (GPUArray, GPUArray
106107
end
107108
end
108109

110+
## generalized blocks of heterogeneous memory
109111

110-
## higher-dimensional copy methods that dispatch to a kernel
112+
# Two-argument copy between GPU arrays: forwards to the four-argument,
# range-based `copyto!`, passing the full `CartesianIndices` of `dest` and of
# `src` as the destination and source ranges.
Base.copyto!(dest::GPUArray, src::GPUArray) =
113+
copyto!(dest, CartesianIndices(dest), src, CartesianIndices(src))
111114

112115
function copy_kernel!(state, dest, dest_offsets, src, src_offsets, shape, shape_dest, shape_source, length)
113116
i = linear_index(state)
@@ -121,10 +124,8 @@ function copy_kernel!(state, dest, dest_offsets, src, src_offsets, shape, shape_
121124
return
122125
end
123126

124-
function Base.copyto!(
125-
dest::GPUArray{T, N}, destcrange::CartesianIndices{N},
126-
src::GPUArray{T, N}, srccrange::CartesianIndices{N}
127-
) where {T, N}
127+
function Base.copyto!(dest::GPUArray{T, N}, destcrange::CartesianIndices{N},
128+
src::GPUArray{U, N}, srccrange::CartesianIndices{N}) where {T, U, N}
128129
shape = size(destcrange)
129130
if shape != size(srccrange)
130131
throw(DimensionMismatch("Ranges don't match their size. Found: $shape, $(size(srccrange))"))
@@ -133,20 +134,14 @@ function Base.copyto!(
133134

134135
dest_offsets = first.(destcrange.indices) .- 1
135136
src_offsets = first.(srccrange.indices) .- 1
136-
ui_shape = shape
137-
gpu_call(
138-
copy_kernel!, dest,
139-
(dest, dest_offsets, src, src_offsets, ui_shape, size(dest), size(src), len),
140-
len
141-
)
137+
gpu_call(copy_kernel!, dest,
138+
(dest, dest_offsets, src, src_offsets, shape, size(dest), size(src), len),
139+
len)
142140
dest
143141
end
144142

145-
146-
function Base.copyto!(
147-
dest::GPUArray{T, N}, destcrange::CartesianIndices{N},
148-
src::AbstractArray{T, N}, srccrange::CartesianIndices{N}
149-
) where {T, N}
143+
function Base.copyto!(dest::GPUArray{T, N}, destcrange::CartesianIndices{N},
144+
src::AbstractArray{T, N}, srccrange::CartesianIndices{N}) where {T, N}
150145
# Is this efficient? Maybe!
151146
# TODO: compare to a pure intrinsic copyto implementation!
152147
# this would mean looping over linear sections of memory and
@@ -157,11 +152,8 @@ function Base.copyto!(
157152
dest
158153
end
159154

160-
161-
function Base.copyto!(
162-
dest::AbstractArray{T, N}, destcrange::CartesianIndices{N},
163-
src::GPUArray{T, N}, srccrange::CartesianIndices{N}
164-
) where {T, N}
155+
function Base.copyto!(dest::AbstractArray{T, N}, destcrange::CartesianIndices{N},
156+
src::GPUArray{T, N}, srccrange::CartesianIndices{N}) where {T, N}
165157
# Is this efficient? Maybe!
166158
dest_gpu = similar(src, size(destcrange))
167159
nrange = CartesianIndices(size(dest_gpu))
@@ -170,21 +162,38 @@ function Base.copyto!(
170162
dest
171163
end
172164

165+
## other
166+
173167
"""
    copy(x::GPUArray)

Return the result of broadcasting `identity` over `x`.
"""
function Base.copy(x::GPUArray)
    return identity.(x)
end
168+
174169
"""
    deepcopy(x::GPUArray)

Return `copy(x)`; for a `GPUArray`, `deepcopy` simply forwards to `copy`.
"""
function Base.deepcopy(x::GPUArray)
    return copy(x)
end
175170

171+
172+
# reinterpret
173+
176174
#=
177-
reinterpret taken from julia base/array.jl
175+
copied from julia base/array.jl
178176
Copyright (c) 2009-2016: Jeff Bezanson, Stefan Karpinski, Viral B. Shah, and other contributors:
179177
180178
https://github.com/JuliaLang/julia/contributors
181179
182-
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
183-
184-
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
185-
186-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
180+
Permission is hereby granted, free of charge, to any person obtaining a copy of this
181+
software and associated documentation files (the "Software"), to deal in the Software
182+
without restriction, including without limitation the rights to use, copy, modify, merge,
183+
publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
184+
to whom the Software is furnished to do so, subject to the following conditions:
185+
186+
The above copyright notice and this permission notice shall be included in all copies or
187+
substantial portions of the Software.
188+
189+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
190+
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
191+
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
192+
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
193+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
194+
DEALINGS IN THE SOFTWARE.
187195
=#
196+
188197
import Base.reinterpret
189198

190199
"""

src/testsuite/base.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function test_base(AT)
4949
fill!(a, 0f0)
5050
copyto!(a, r1, y, r2)
5151
@test Array(a) == x
52-
52+
5353
x = fill(0f0, (10,))
5454
y = rand(Float32, (20,))
5555
a = AT(x)
@@ -71,6 +71,13 @@ function test_base(AT)
7171
copyto!(x, r1[1], y, r2[1])
7272
copyto!(a, r1, b, r2)
7373
@test x == Array(a)
74+
75+
x = fill(0., (10,))
76+
y = fill(1, (10,))
77+
a = AT(x)
78+
b = AT(y)
79+
copyto!(a, b)
80+
@test Float64.(y) == Array(a)
7481
end
7582

7683
@testset "vcat + hcat" begin

0 commit comments

Comments
 (0)