diff --git a/src/diskindex.jl b/src/diskindex.jl index 4005c95..a391e60 100644 --- a/src/diskindex.jl +++ b/src/diskindex.jl @@ -29,13 +29,15 @@ struct DiskIndex{N,M,A<:Tuple,B<:Tuple,C<:Tuple} data_indices::C end function DiskIndex( - output_size::NTuple{N,<:Integer}, - temparray_size::NTuple{M,<:Integer}, + output_size::Tuple{Vararg{Integer}}, + temparray_size::Tuple{Vararg{Integer}}, output_indices::Tuple, temparray_indices::Tuple, data_indices::Tuple -) where {N,M} - DiskIndex(Int.(output_size), Int.(temparray_size), output_indices, temparray_indices, data_indices) +) + output_size_int = map(Int, output_size) + temparray_size_int = map(Int, temparray_size) + DiskIndex(output_size_int, temparray_size_int, output_indices, temparray_indices, data_indices) end DiskIndex(a, i) = DiskIndex(a, i, batchstrategy(a)) DiskIndex(a, i, batch_strategy) = @@ -54,9 +56,41 @@ function _resolve_indices(chunks, i, indices_pre::DiskIndex, strategy::BatchStra indices_new, chunksrem = process_index(inow, chunks, strategy) _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy) end +# Some (pretty stupid) hacks to get around Base recursion limiting https://github.com/JuliaLang/julia/pull/48059 +# TODO: We can remove these if Base sorts this out. +# This makes 3 arg type stable +function _resolve_indices(chunks::Tuple{<:Any}, i::Tuple{<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy) + inow = first(i) + indices_new, chunksrem = process_index(inow, chunks, strategy) + return merge_index(indices_pre, indices_new) +end +# This makes 4 arg type stable +function _resolve_indices(chunks::Tuple{<:Any,<:Any}, i::Tuple{<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy) + inow = first(i) + indices_new, chunksrem = process_index(inow, chunks, strategy) + return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy) +end +# This makes 5 arg type stable +function _resolve_indices(chunks::Tuple{<:Any,<:Any,<:Any}, i::Tuple{<:Any,<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy) + inow = first(i) + indices_new, chunksrem = process_index(inow, chunks, strategy) + return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy) +end +# This makes 6 arg type stable +function _resolve_indices(chunks::Tuple{<:Any,<:Any,<:Any,<:Any}, i::Tuple{<:Any,<:Any,<:Any,<:Any}, indices_pre::DiskIndex, strategy::BatchStrategy) + inow = first(i) + indices_new, chunksrem = process_index(inow, chunks, strategy) + return _resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy) +end # Splat out CartesianIndex as regular indices function _resolve_indices( - chunks, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy + chunks::Tuple, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy +) + _resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy) +end +# This method is needed to resolve ambiguity +function _resolve_indices( + chunks::Tuple{<:Any}, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy ) _resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy) end @@ -112,23 +146,33 @@ Calculate indices for `i` the first chunk/s in `chunks` Returns a [`DiskIndex`](@ref), and the remaining chunks. """ process_index(i, chunks, ::NoBatch) = process_index(i, chunks) -function process_index(i::CartesianIndex{N}, chunks, ::NoBatch) where {N} +function process_index(i::CartesianIndex{N}, chunks::Tuple, ::NoBatch) where {N} _, chunksrem = splitchunks(i, chunks) di = DiskIndex((), map(one, i.I), (), (1,), map(i -> i:i, i.I)) + return di, chunksrem end process_index(inow::Integer, chunks) = DiskIndex((), (1,), (), (1,), (inow:inow,)), tail(chunks) function process_index(::Colon, chunks) s = arraysize_from_chunksize(first(chunks)) - DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), tail(chunks) + di = DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),) + return di, tail(chunks) end function process_index(i::AbstractUnitRange{<:Integer}, chunks, ::NoBatch) - DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), tail(chunks) + di = DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)) + return di::DiskIndex, tail(chunks)::Tuple end function process_index(i::AbstractArray{<:Integer}, chunks, ::NoBatch) indmin, indmax = isempty(i) ? (1, 0) : extrema(i) - di = DiskIndex(size(i), ((indmax - indmin + 1),), map(_ -> Colon(), size(i)), ((i .- (indmin - 1)),), (indmin:indmax,)) + + output_size = size(i) + temparray_size = ((indmax - indmin + 1),) + output_indices = map(_ -> Colon(), size(i)) + temparray_indices = ((i .- (indmin - 1)),) + data_indices = (indmin:indmax,) + di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices) + return di, tail(chunks) end function process_index(i::AbstractArray{Bool,N}, chunks, ::NoBatch) where {N} @@ -136,9 +180,14 @@ function process_index(i::AbstractArray{Bool,N}, chunks, ::NoBatch) where {N} s = arraysize_from_chunksize.(chunksnow) cindmin, cindmax = extrema(view(CartesianIndices(s), i)) indmin, indmax = cindmin.I, cindmax.I - tempsize = indmax .- indmin .+ 1 - tempinds = view(i, range.(indmin, indmax)...) - di = DiskIndex((sum(i),), tempsize, (Colon(),), (tempinds,), range.(indmin, indmax)) + + output_size = (sum(i),) + temparray_size = map((max, min) -> max - min + 1, indmax, indmin) + output_indices = (Colon(),) + temparray_indices = (view(i, map(range, indmin, indmax)...),) + data_indices = map(range, indmin, indmax) + di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices) + return di, chunksrem end function process_index(i::AbstractArray{<:CartesianIndex{N}}, chunks, ::NoBatch) where {N} @@ -151,17 +200,26 @@ function process_index(i::AbstractArray{<:CartesianIndex{N}}, chunks, ::NoBatch) extrema(v) end indmin, indmax = cindmin.I, cindmax.I - tempsize = indmax .- indmin .+ 1 - tempoffset = cindmin - oneunit(cindmin) - tempinds = i .- (CartesianIndex(tempoffset),) - outinds = map(_ -> Colon(), size(i)) - di = DiskIndex(size(i), tempsize, outinds, (tempinds,), range.(indmin, indmax)) + + output_size = size(i) + temparray_size = map((max, min) -> max - min + 1, indmax, indmin) + temparray_offset = cindmin - oneunit(cindmin) + temparray_indices = (i .- (CartesianIndex(temparray_offset),),) + output_indices = map(_ -> Colon(), size(i)) + data_indices = map(range, indmin, indmax) + di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices) + return di, chunksrem end function process_index(i::CartesianIndices{N}, chunks, ::NoBatch) where {N} _, chunksrem = splitchunks(i, chunks) - cols = map(_ -> Colon(), i.indices) - di = DiskIndex(length.(i.indices), length.(i.indices), cols, cols, i.indices) + + output_size = map(length, i.indices) + temparray_size = map(length, i.indices) + output_indices = temparray_indices = map(_ -> Colon(), i.indices) + data_indices = i.indices + di = DiskIndex(output_size, temparray_size, output_indices, temparray_indices, data_indices) + return di, chunksrem end diff --git a/test/runtests.jl b/test/runtests.jl index 717e2fb..38fac23 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using TraceFuns, Suppressor # using JET # JET.report_package(DiskArrays) -if VERSION >= v"1.9.0" +@testset "Aqua.jl" begin Aqua.test_ambiguities([DiskArrays, Base, Core]) Aqua.test_unbound_args(DiskArrays) Aqua.test_stale_deps(DiskArrays) @@ -1087,3 +1087,16 @@ end @test length(unique(a)) == length(unique(identity, a)) == 8 @test unique(x->x>3, a) == [1,4] end + +@testset "type stable DiskIndex" begin + a = AccessCountDiskArray(reshape(1:96, 2, 3, 4, 2, 2, 1), chunksize=(2, 2, 2, 2, 2, 1)) + a_view3 = @view a[:, 1:2, 2:4, 1, 1, 1] + a_view4 = @view a[:, 1:2, 2:4, :, 1, 1] + a_view5 = @view a[:, 1:2, 2:4, :, :, 1] + a_view6 = @view a[:, 1:2, 2:4, :, :, :] + + @inferred DiskArrays.DiskIndex(a_view3, (1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex + @inferred DiskArrays.DiskIndex(a_view4, (1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex + @inferred DiskArrays.DiskIndex(a_view5, (1:1, 1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex + @inferred DiskArrays.DiskIndex(a_view6, (1:1, 1:1, 1:1, 1:1, 1:1, 1:1), DiskArrays.NoBatch()) #DiskArrays.DiskIndex +end