Skip to content

Commit b0a4a7b

Browse files
rafaqzmeggart
andauthored
fix and test mixed CartesianIndex (#260)
* fix and test mixed CartesianIndex * fix view and test * move code out of macros and copy base julia view * bugfix view * esc * revert splitchunks (broken) * Re-fix bug in splitchunks --------- Co-authored-by: meggart <[email protected]>
1 parent 80274cd commit b0a4a7b

File tree

5 files changed

+63
-22
lines changed

5 files changed

+63
-22
lines changed

src/DiskArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import ConstructionBase
44

55
using LRUCache: LRUCache, LRU
66

7+
using Base: tail
8+
79
# Use the README as the module docs
810
@doc let
911
path = joinpath(dirname(@__DIR__), "README.md")

src/diskindex.jl

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@ DiskIndex(a, i::Tuple{<:AbstractVector{<:Integer}}, batchstrategy) =
5252
function _resolve_indices(chunks, i, indices_pre::DiskIndex, strategy::BatchStrategy)
5353
inow = first(i)
5454
indices_new, chunksrem = process_index(inow, chunks, strategy)
55-
_resolve_indices(chunksrem, Base.tail(i), merge_index(indices_pre, indices_new), strategy)
55+
_resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
56+
end
57+
# Splat out CartesianIndex as regular indices
58+
function _resolve_indices(
59+
chunks, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
60+
)
61+
_resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
5662
end
5763
_resolve_indices(::Tuple{}, ::Tuple{}, indices::DiskIndex, strategy::BatchStrategy) = indices
5864
# No dimension left in array, only singular indices allowed
@@ -61,17 +67,25 @@ function _resolve_indices(::Tuple{}, i, indices_pre::DiskIndex, strategy::BatchS
6167
(length(inow) == 1 && only(inow) == 1) || throw(ArgumentError("Trailing indices must be 1"))
6268
indices_new = DiskIndex(size(inow), (), size(inow), (), ())
6369
indices = merge_index(indices_pre, indices_new)
64-
_resolve_indices((), Base.tail(i), indices, strategy)
70+
_resolve_indices((), tail(i), indices, strategy)
71+
end
72+
# Splat out CartesianIndex as regular trailing indices
73+
function _resolve_indices(
74+
::Tuple{}, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
75+
)
76+
_resolve_indices((), (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
6577
end
6678
# Still dimensions left, but no indices available
6779
function _resolve_indices(chunks, ::Tuple{}, indices_pre::DiskIndex, strategy::BatchStrategy)
6880
chunksnow = first(chunks)
69-
arraysize_from_chunksize(chunksnow) == 1 || throw(ArgumentError("Indices can only be omitted for trailing singleton dimensions"))
81+
checktrailing(arraysize_from_chunksize(chunksnow))
7082
indices_new = add_dimension_index(strategy)
7183
indices = merge_index(indices_pre, indices_new)
72-
_resolve_indices(Base.tail(chunks), (), indices, strategy)
84+
_resolve_indices(tail(chunks), (), indices, strategy)
7385
end
7486

87+
checktrailing(i) = i == 1 || throw(ArgumentError("Indices can only be omitted for trailing singleton dimensions"))
88+
7589
add_dimension_index(::NoBatch) = DiskIndex((), (1,), (), (1,), (1:1,))
7690
add_dimension_index(::Union{ChunkRead,SubRanges}) = DiskIndex((), (1,), ([()],), ([(1,)],), ([(1:1,)],))
7791

@@ -98,18 +112,24 @@ Calculate indices for `i` the first chunk/s in `chunks`
98112
Returns a [`DiskIndex`](@ref), and the remaining chunks.
99113
"""
100114
process_index(i, chunks, ::NoBatch) = process_index(i, chunks)
101-
process_index(inow::Integer, chunks) = DiskIndex((), (1,), (), (1,), (inow:inow,)), Base.tail(chunks)
115+
function process_index(i::CartesianIndex{N}, chunks, ::NoBatch) where {N}
116+
_, chunksrem = splitchunks(i, chunks)
117+
di = DiskIndex((), map(one, i.I), (), (1,), map(i -> i:i, i.I))
118+
return di, chunksrem
119+
end
120+
process_index(inow::Integer, chunks) =
121+
DiskIndex((), (1,), (), (1,), (inow:inow,)), tail(chunks)
102122
function process_index(::Colon, chunks)
103123
s = arraysize_from_chunksize(first(chunks))
104-
DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), Base.tail(chunks)
124+
DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), tail(chunks)
105125
end
106126
function process_index(i::AbstractUnitRange{<:Integer}, chunks, ::NoBatch)
107-
DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), Base.tail(chunks)
127+
DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), tail(chunks)
108128
end
109129
function process_index(i::AbstractArray{<:Integer}, chunks, ::NoBatch)
110130
indmin, indmax = isempty(i) ? (1, 0) : extrema(i)
111131
di = DiskIndex(size(i), ((indmax - indmin + 1),), map(_ -> Colon(), size(i)), ((i .- (indmin - 1)),), (indmin:indmax,))
112-
return di, Base.tail(chunks)
132+
return di, tail(chunks)
113133
end
114134
function process_index(i::AbstractArray{Bool,N}, chunks, ::NoBatch) where {N}
115135
chunksnow, chunksrem = splitchunks(i, chunks)
@@ -162,7 +182,12 @@ splitchunks(i::CartesianIndex, chunks) = splitchunks(i.I, (), chunks)
162182
splitchunks(_, chunks) = (first(chunks),), Base.tail(chunks)
163183
splitchunks(si, chunksnow, chunksrem) =
164184
splitchunks(Base.tail(si), (chunksnow..., first(chunksrem)), Base.tail(chunksrem))
185+
function splitchunks(si,chunksnow, ::Tuple{})
186+
only(first(si)) == 1 || throw(ArgumentError("Trailing indices must be 1"))
187+
splitchunks(Base.tail(si), chunksnow, ())
188+
end
165189
splitchunks(::Tuple{}, chunksnow, chunksrem) = (chunksnow, chunksrem)
190+
splitchunks(::Tuple{}, chunksnow, chunksrem::Tuple{}) = (chunksnow, chunksrem)
166191

167192
"""
168193
output_aliasing(di::DiskIndex, ndims_dest, ndims_source)

src/indexing.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ function getindex_disk(a::AbstractArray, i::Integer)
5656
return only(outputarray)
5757
end
5858
getindex_disk(a::AbstractArray, i...) = getindex_disk!(nothing, a, i...)
59+
getindex_disk(a::AbstractArray, i::ChunkIndex{<:Any,OneBasedChunks}) =
60+
a[eachchunk(a)[i.I]...]
61+
getindex_disk(a::AbstractArray, i::ChunkIndex{<:Any,OffsetChunks}) =
62+
wrapchunk(a[nooffset(i)], eachchunk(a)[i.I])
5963

6064
function getindex_disk!(out::Union{Nothing,AbstractArray}, a::AbstractArray, i...)
6165
# Check if we can write once or need to use multiple batches
@@ -202,7 +206,7 @@ end
202206
Generate an `Array` to pass to `readblock!`
203207
"""
204208
function create_outputarray(out::AbstractArray, a::AbstractArray, output_size::Tuple)
205-
size(out) == output_size || throw(ArgumentError("Expected output array size of $output_size"))
209+
size(out) == output_size || throw(ArgumentError("Expected output array size of $output_size, got $(size(out))"))
206210
return out
207211
end
208212
create_outputarray(::Nothing, a::AbstractArray, output_size::Tuple) =
@@ -306,10 +310,6 @@ macro implement_getindex(t)
306310
quote
307311
DiskArrays.isdisk(::Type{<:$t}) = true
308312
Base.getindex(a::$t, i...) = getindex_disk(a, i...)
309-
@inline Base.getindex(a::$t, i::ChunkIndex{<:Any,OneBasedChunks}) =
310-
a[eachchunk(a)[i.I]...]
311-
@inline Base.getindex(a::$t, i::ChunkIndex{<:Any,OffsetChunks}) =
312-
wrapchunk(a[nooffset(i)], eachchunk(a)[i.I])
313313
function DiskArrays.ChunkIndices(a::$t; offset=false)
314314
return ChunkIndices(
315315
map(s -> 1:s, size(eachchunk(a))), offset ? OffsetChunks() : OneBasedChunks()

src/subarray.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,20 @@ function eachchunk_view(::Chunked, vv)
4444
end
4545
eachchunk_view(::Unchunked, a) = estimate_chunksize(a)
4646

47-
# Implementaion macro
47+
function view_disk(A, I...)
48+
@inline
49+
# Modified from Base.view
50+
J = to_indices(A, I)
51+
@boundscheck checkbounds(A, J...)
52+
J′ = Base.rm_singleton_indices(ntuple(Returns(true), Val(ndims(A))), J...)
53+
SubDiskArray(Base.unsafe_view(A, J′...))
54+
end
4855

56+
# Implementaion macro
4957
macro implement_subarray(t)
5058
t = esc(t)
5159
quote
52-
function Base.view(a::$t, i...)
53-
i2 = _replace_colon.(size(a), i)
54-
return SubDiskArray(SubArray(a, i2))
55-
end
56-
Base.view(a::$t, i::CartesianIndices) = view(a, i.indices...)
60+
@inline Base.view(a::$t, i...) = view_disk(a, i...)
5761
Base.vec(a::$t) = view(a, :)
5862
end
5963
end

test/runtests.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ end
3838
@test a[CartesianIndex(1, 2), 3] == 15
3939
@test a[CartesianIndex(1, 2, 3)] == 15
4040
end
41+
4142
@testset "isdisk" begin
4243
a = reshape(1:24, 2, 3, 4)
4344
da = AccessCountDiskArray(a; chunksize=(2, 2, 2))
@@ -52,16 +53,22 @@ end
5253

5354
function test_getindex(a)
5455
@test a[2, 3, 1] == 10
56+
@test a[CartesianIndex(2, 3), 1] == 10
57+
@test a[2, CartesianIndex(3,), 1] == 10
58+
@test a[CartesianIndex(2, 3, 1)] == 10
59+
@test a[1:2, CartesianIndex(3, 1, 1)] == 9:10
5560
@test a[2, 3] == 10
61+
@test a[CartesianIndex(2, 3)] == 10
5662
@test a[2, 3, 1, 1] == 10
5763
@test a[:, 1] == [1, 2, 3, 4]
5864
@test a[1:2, 1:2, 1, 1] == [1 5; 2 6]
5965
@test a[end:-1:1, 1, 1] == [4, 3, 2, 1]
6066
@test a[2, 3, 1, 1:1] == [10]
6167
@test a[2, 3, 1, [1], [1]] == fill(10, 1, 1)
6268
@test a[:, 3, 1, [1]] == reshape(9:12, 4, 1)
69+
@test a[:, CartesianIndex(3, 1), [1]] == reshape(9:12, 4, 1)
6370
@test a[CartesianIndices((1:2, 1:2)), 1] == [1 5; 2 6]
64-
@test getindex_count(a) == 10
71+
@test getindex_count(a) == 16
6572
# Test bitmask indexing
6673
m = falses(4, 5, 1)
6774
m[2, [1, 2, 3, 5], 1] .= true
@@ -74,6 +81,7 @@ function test_getindex(a)
7481
@test a[2:4:14] == [2, 6, 10, 14]
7582
# Test that readblock was called exactly onces for every getindex
7683
@test a[2:2:4, 1:2:5] == [2 10 18; 4 12 20]
84+
@test a[2:2:4, 1:2:5] == [2 10 18; 4 12 20]
7785
@test a[[1, 3, 4], [1, 3], 1] == [1 9; 3 11; 4 12]
7886
@testset "allowscalar" begin
7987
DiskArrays.allowscalar(false)
@@ -86,7 +94,7 @@ function test_getindex(a)
8694
end
8795

8896
function test_setindex(a)
89-
a[1, 1, 1] = 1
97+
a[CartesianIndex(1, 1), 1] = 1
9098
a[1, 2] = 2
9199
a[1, 3, 1, 1] = 3
92100
a[2:2, :] = [1, 2, 3, 4, 5]
@@ -118,10 +126,12 @@ function test_view(a)
118126
v[1:2, 1] = [1, 2]
119127
v[1:2, 2:3] = [4 4; 4 4]
120128
@test v[1:2, 1] == [1, 2]
129+
@test v[1:2, CartesianIndex(1,)] == [1, 2]
130+
@test v[1:2, CartesianIndex(1, 1)] == [1, 2]
121131
@test v[1:2, 2:3] == [4 4; 4 4]
122132
@test trueparent(a)[2:3, 2] == [1, 2]
123133
@test trueparent(a)[2:3, 3:4] == [4 4; 4 4]
124-
@test getindex_count(a) == 2
134+
@test getindex_count(a) == 4
125135
@test setindex_count(a) == 2
126136

127137
v2 = view(a, 2:3, 2:4, Int[])

0 commit comments

Comments
 (0)