Skip to content

Commit e211986

Browse files
felixcremermeggart
andauthored
Check for zero length index in copyto! and same length of the indices (#169)
* Check for zero length index in copyto! and same length of the indices The check for the same length of the indices is copied from the behaviour of copyto! on normal arrays. * Remove println * Replace any(==(0)) with isempty Co-authored-by: Fabian Gans <[email protected]> --------- Co-authored-by: Fabian Gans <[email protected]>
1 parent 309f613 commit e211986

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/array.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,18 @@ function _copyto!(dest::AbstractArray, source::AbstractArray)
6161
reshape(dest, size(source)) .= source
6262
return dest
6363
end
64-
_copyto!(dest, Rdest, src, Rsrc) = view(dest, Rdest) .= view(src, Rsrc)
6564

65+
function _copyto!(dest, Rdest, src, Rsrc)
66+
if size(Rdest) != size(Rsrc)
67+
throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
68+
end
69+
70+
if isempty(Rdest)
71+
# This check is here to catch #168
72+
return dest
73+
end
74+
view(dest, Rdest) .= view(src, Rsrc)
75+
end
6676
# Use a view for lazy reverse
6777
_reverse(a, ::Colon) = _reverse(a, ntuple(identity, ndims(a)))
6878
_reverse(a, dims::Int) = _reverse(a, (dims,))

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,12 @@ end
663663
copyto!(x, a_disk)
664664
@test x == a
665665
copyto!(x, CartesianIndices((1:3, 1:2)), a_disk, CartesianIndices((8:10, 8:9)))
666+
# Test copyto! with zero length index
667+
x_empty = Matrix{Int64}(undef, 0,2)
668+
copyto!(x_empty, CartesianIndices((1:0, 1:2)), a_disk, CartesianIndices((8:7, 8:9)))
669+
# copyto! with different length should throw an error
670+
@test_throws ArgumentError copyto!(x, CartesianIndices((1:1, 1:2)), a_disk, CartesianIndices((4:6, 8:9)))
671+
666672
end
667673

668674
@test collect(reverse(a_disk)) == reverse(a)

0 commit comments

Comments
 (0)