Skip to content

Commit 42ae9b9

Browse files
authored
solve type stability of DiskIndex (#267)
* solve type stability of DiskIndex * assume_effects not needed * Aqua passing * fix a typo
1 parent 72db1f1 commit 42ae9b9

File tree

2 files changed

+91
-20
lines changed

2 files changed

+91
-20
lines changed

src/diskindex.jl

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ struct DiskIndex{N,M,A<:Tuple,B<:Tuple,C<:Tuple}
2929
data_indices::C
3030
end
3131
function DiskIndex(
32-
output_size::NTuple{N,<:Integer},
33-
temparray_size::NTuple{M,<:Integer},
32+
output_size::Tuple{Vararg{Integer}},
33+
temparray_size::Tuple{Vararg{Integer}},
3434
output_indices::Tuple,
3535
temparray_indices::Tuple,
3636
data_indices::Tuple
37-
) where {N,M}
38-
DiskIndex(Int.(output_size), Int.(temparray_size), output_indices, temparray_indices, data_indices)
37+
)
38+
output_size_int = map(Int, output_size)
39+
temparray_size_int = map(Int, temparray_size)
40+
DiskIndex(output_size_int, temparray_size_int, output_indices, temparray_indices, data_indices)
3941
end
4042
DiskIndex(a, i) = DiskIndex(a, i, batchstrategy(a))
4143
DiskIndex(a, i, batch_strategy) =
@@ -54,9 +56,41 @@ function _resolve_indices(chunks, i, indices_pre::DiskIndex, strategy::BatchStra
5456
indices_new, chunksrem = process_index(inow, chunks, strategy)
5557
_resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
5658
end
59+
# Some (pretty stupid) hacks to get around Base recursion limiting https://github.com/JuliaLang/julia/pull/48059
60+
# TODO: We can remove these if Base sorts this out.
61+
# This makes 3 arg type stable
62+
function _resolve_indices(chunks::Tuple{<:Any}, i::Tuple{<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy)
63+
inow = first(i)
64+
indices_new, chunksrem = process_index(inow, chunks, strategy)
65+
return merge_index(indices_pre, indices_new)
66+
end
67+
# This makes 4 arg type stable
68+
function _resolve_indices(chunks::Tuple{<:Any,<:Any}, i::Tuple{<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy)
69+
inow = first(i)
70+
indices_new, chunksrem = process_index(inow, chunks, strategy)
71+
return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
72+
end
73+
# This makes 5 arg type stable
74+
function _resolve_indices(chunks::Tuple{<:Any,<:Any,<:Any}, i::Tuple{<:Any,<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy)
75+
inow = first(i)
76+
indices_new, chunksrem = process_index(inow, chunks, strategy)
77+
return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
78+
end
79+
# This makes 6 arg type stable
80+
function _resolve_indices(chunks::Tuple{<:Any,<:Any,<:Any,<:Any}, i::Tuple{<:Any,<:Any,<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy)
81+
inow = first(i)
82+
indices_new, chunksrem = process_index(inow, chunks, strategy)
83+
return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
84+
end
5785
# Splat out CartesianIndex as regular indices
5886
function _resolve_indices(
59-
chunks, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
87+
chunks::Tuple, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
88+
)
89+
_resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
90+
end
91+
# This method is needed to resolve ambiguity
92+
function _resolve_indices(
93+
chunks::Tuple{<:Any}, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
6094
)
6195
_resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
6296
end
@@ -112,33 +146,48 @@ Calculate indices for `i` the first chunk/s in `chunks`
112146
Returns a [`DiskIndex`](@ref), and the remaining chunks.
113147
"""
114148
process_index(i, chunks, ::NoBatch) = process_index(i, chunks)
115-
function process_index(i::CartesianIndex{N}, chunks, ::NoBatch) where {N}
149+
function process_index(i::CartesianIndex{N}, chunks::Tuple, ::NoBatch) where {N}
116150
_, chunksrem = splitchunks(i, chunks)
117151
di = DiskIndex((), map(one, i.I), (), (1,), map(i -> i:i, i.I))
152+
118153
return di, chunksrem
119154
end
120155
process_index(inow::Integer, chunks) =
121156
DiskIndex((), (1,), (), (1,), (inow:inow,)), tail(chunks)
122157
function process_index(::Colon, chunks)
123158
s = arraysize_from_chunksize(first(chunks))
124-
DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), tail(chunks)
159+
di = DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),)
160+
return di, tail(chunks)
125161
end
126162
function process_index(i::AbstractUnitRange{<:Integer}, chunks, ::NoBatch)
127-
DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), tail(chunks)
163+
di = DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,))
164+
return di::DiskIndex, tail(chunks)::Tuple
128165
end
129166
function process_index(i::AbstractArray{<:Integer}, chunks, ::NoBatch)
130167
indmin, indmax = isempty(i) ? (1, 0) : extrema(i)
131-
di = DiskIndex(size(i), ((indmax - indmin + 1),), map(_ -> Colon(), size(i)), ((i .- (indmin - 1)),), (indmin:indmax,))
168+
169+
output_size = size(i)
170+
temparray_size = ((indmax - indmin + 1),)
171+
output_indices = map(_ -> Colon(), size(i))
172+
temparray_indices = ((i .- (indmin - 1)),)
173+
data_indices = (indmin:indmax,)
174+
di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices)
175+
132176
return di, tail(chunks)
133177
end
134178
function process_index(i::AbstractArray{Bool,N}, chunks, ::NoBatch) where {N}
135179
chunksnow, chunksrem = splitchunks(i, chunks)
136180
s = arraysize_from_chunksize.(chunksnow)
137181
cindmin, cindmax = extrema(view(CartesianIndices(s), i))
138182
indmin, indmax = cindmin.I, cindmax.I
139-
tempsize = indmax .- indmin .+ 1
140-
tempinds = view(i, range.(indmin, indmax)...)
141-
di = DiskIndex((sum(i),), tempsize, (Colon(),), (tempinds,), range.(indmin, indmax))
183+
184+
output_size = (sum(i),)
185+
temparray_size = map((max, min) -> max - min + 1, indmax, indmin)
186+
output_indices = (Colon(),)
187+
temparray_indices = (view(i, map(range, indmin, indmax)...),)
188+
data_indices = map(range, indmin, indmax)
189+
di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices)
190+
142191
return di, chunksrem
143192
end
144193
function process_index(i::AbstractArray{<:CartesianIndex{N}}, chunks, ::NoBatch) where {N}
@@ -151,17 +200,26 @@ function process_index(i::AbstractArray{<:CartesianIndex{N}}, chunks, ::NoBatch)
151200
extrema(v)
152201
end
153202
indmin, indmax = cindmin.I, cindmax.I
154-
tempsize = indmax .- indmin .+ 1
155-
tempoffset = cindmin - oneunit(cindmin)
156-
tempinds = i .- (CartesianIndex(tempoffset),)
157-
outinds = map(_ -> Colon(), size(i))
158-
di = DiskIndex(size(i), tempsize, outinds, (tempinds,), range.(indmin, indmax))
203+
204+
output_size = size(i)
205+
temparray_size = map((max, min) -> max - min + 1, indmax, indmin)
206+
temparray_offset = cindmin - oneunit(cindmin)
207+
temparray_indices = (i .- (CartesianIndex(temparray_offset),),)
208+
output_indices = map(_ -> Colon(), size(i))
209+
data_indices = map(range, indmin, indmax)
210+
di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices)
211+
159212
return di, chunksrem
160213
end
161214
function process_index(i::CartesianIndices{N}, chunks, ::NoBatch) where {N}
162215
_, chunksrem = splitchunks(i, chunks)
163-
cols = map(_ -> Colon(), i.indices)
164-
di = DiskIndex(length.(i.indices), length.(i.indices), cols, cols, i.indices)
216+
217+
output_size = map(length, i.indices)
218+
temparray_size = map(length, i.indices)
219+
output_indices = temparray_indices = map(_ -> Colon(), i.indices)
220+
data_indices = i.indices
221+
di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices)
222+
165223
return di, chunksrem
166224
end
167225

test/runtests.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using TraceFuns, Suppressor
1111
# using JET
1212
# JET.report_package(DiskArrays)
1313

14-
if VERSION >= v"1.9.0"
14+
@testset "Aqua.jl" begin
1515
Aqua.test_ambiguities([DiskArrays, Base, Core])
1616
Aqua.test_unbound_args(DiskArrays)
1717
Aqua.test_stale_deps(DiskArrays)
@@ -1103,3 +1103,16 @@ end
11031103
@test length(unique(a)) == length(unique(identity, a)) == 8
11041104
@test unique(x->x>3, a) == [1,4]
11051105
end
1106+
1107+
@testset "type stable DiskIndex" begin
1108+
a = AccessCountDiskArray(reshape(1:96, 2, 3, 4, 2, 2, 1), chunksize=(2, 2, 2, 2, 2, 1))
1109+
a_view3 = @view a[:, 1:2, 2:4, 1, 1, 1]
1110+
a_view4 = @view a[:, 1:2, 2:4, :, 1, 1]
1111+
a_view5 = @view a[:, 1:2, 2:4, :, :, 1]
1112+
a_view6 = @view a[:, 1:2, 2:4, :, :, :]
1113+
1114+
@inferred DiskArrays.DiskIndex(a_view3, (1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex
1115+
@inferred DiskArrays.DiskIndex(a_view4, (1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex
1116+
@inferred DiskArrays.DiskIndex(a_view5, (1:1, 1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex
1117+
@inferred DiskArrays.DiskIndex(a_view6, (1:1, 1:1, 1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex
1118+
end

0 commit comments

Comments
 (0)