Skip to content

Commit 17a8518

Browse files
authored
Cleanup indices info + better docs (#342)
Cleanup and document IndicesInfo Before making this public I wanted to ensure that we had control over the construction of `IndicesInfo` so it was guaranteed to be valid. Now it has inner struct constructors only. Also provided `ndims_index` and `ndims_shape` so all information can be accessed without directly looking at `IndicesInfo`'s parametric typing.
1 parent 9b74e3b commit 17a8518

File tree

5 files changed

+154
-73
lines changed

5 files changed

+154
-73
lines changed

docs/src/api.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ ArrayInterfaceCore.device
1212
ArrayInterfaceCore.defines_strides
1313
ArrayInterfaceCore.fast_matrix_colors
1414
ArrayInterfaceCore.fast_scalar_indexing
15+
ArrayInterfaceCore.indices_do_not_alias
16+
ArrayInterfaceCore.instances_do_not_alias
1517
ArrayInterfaceCore.is_forwarding_wrapper
1618
ArrayInterfaceCore.ismutable
1719
ArrayInterfaceCore.isstructured
@@ -84,6 +86,7 @@ ArrayInterface.from_parent_dims
8486
ArrayInterface.getindex
8587
ArrayInterface.indices
8688
ArrayInterface.insert
89+
ArrayInterface.length
8790
ArrayInterface.lazy_axes
8891
ArrayInterface.offset1
8992
ArrayInterface.offsets
@@ -106,8 +109,8 @@ ArrayInterface.BroadcastAxis
106109
ArrayInterface.LazyAxis
107110
ArrayInterface.OptionallyStaticStepRange
108111
ArrayInterface.OptionallyStaticUnitRange
109-
ArrayInteraface.SOneTo
110-
ArrayInteraface.SUnitRange
112+
ArrayInterface.SOneTo
113+
ArrayInterface.SUnitRange
111114
ArrayInterface.StrideIndex
112115
```
113116

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 122 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -806,74 +806,122 @@ ndims_shape(x) = ndims_shape(typeof(x))
806806
end
807807

808808
"""
809+
IndicesInfo{N}(inds::Tuple) -> IndicesInfo{N}(typeof(inds))
809810
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,pdims,cdims}()
811+
IndicesInfo(inds::Tuple) -> IndicesInfo(typeof(inds))
812+
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{maximum(pdims),pdims,cdims}()
810813
811-
Provides basic trait information for each index type in in the tuple `T`. `pdims` and
812-
`cdims` are dimension mappings to the parent and child dimensions respectively.
814+
815+
Maps a tuple of indices to `N` dimensions. The resulting `pdims` is a tuple where each
816+
field in `inds` (or field type in `T`) corresponds to the parent dimensions accessed.
817+
`cdims` similarly maps indices to the resulting child array produced after indexing with
818+
`inds`. If `N` is not provided then it is assumed that all indices are represented by parent
819+
dimensions and there are no trailing dimensions accessed. These may be accessed by through
820+
`parentdims(info::IndicesInfo)` and `childdims(info::IndicesInfo)`. If `N` is not provided,
821+
it is assumed that no indices are accessing trailing dimensions (which are represented as
822+
`0` in `parentdims(info)[index_position]`).
823+
824+
The the fields and types of `IndicesInfo` should not be accessed directly.
825+
Instead [`parentdims`](@ref), [`childdims`](@ref), [`ndims_index`](@ref), and
826+
[`ndims_shape`](@ref) should be used to extract relevant information.
813827
814828
# Examples
815829
816830
```julia
817-
julia> using ArrayInterfaceCore: IndicesInfo
831+
julia> using ArrayInterfaceCore: IndicesInfo, parentdims, childdims, ndims_index, ndims_shape
832+
833+
julia> info = IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
834+
835+
julia> parentdims(info) # the last two indices access trailing dimensions
836+
(1, (2, 3), 4, 5, 0, 0)
837+
838+
julia> childdims(info)
839+
(1, 2, 0, (3, 4), 5, 0)
840+
841+
julia> childdims(info)[3] # index 3 accesses a parent dimension but is dropped in the child array
842+
0
818843
819-
julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
820-
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
844+
julia> ndims_index(info)
845+
5
846+
847+
julia> ndims_shape(info)
848+
5
849+
850+
julia> info = IndicesInfo(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
851+
852+
julia> parentdims(info) # assumed no trailing dimensions
853+
(1, (2, 3), 4, 5, 6, 7)
854+
855+
julia> ndims_index(info) # assumed no trailing dimensions
856+
7
821857
822858
```
823859
"""
824-
struct IndicesInfo{N,NI,NS} end
825-
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
826-
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
827-
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
828-
end
829-
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
830-
_indices_info(
831-
Val{_find_first_true(map_tuple_type(is_splat_index, T))}(),
832-
IndicesInfo{N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)}()
833-
)
834-
end
835-
function _indices_info(::Val{nothing}, ::IndicesInfo{1,(1,),NS}) where {NS}
836-
ns1 = getfield(NS, 1)
837-
IndicesInfo{1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
838-
end
839-
function _indices_info(::Val{nothing}, ::IndicesInfo{N,(1,),NS}) where {N,NS}
840-
ns1 = getfield(NS, 1)
841-
IndicesInfo{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
842-
end
843-
@inline function _indices_info(::Val{nothing}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS}
844-
if sum(NI) > N
845-
IndicesInfo{N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)}()
846-
else
847-
IndicesInfo{N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)}()
860+
struct IndicesInfo{Np,pdims,cdims,Nc}
861+
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
862+
SI = _find_first_true(map_tuple_type(is_splat_index, T))
863+
NI = map_tuple_type(ndims_index, T)
864+
NS = map_tuple_type(ndims_shape, T)
865+
if SI === nothing
866+
ndi = NI
867+
nds = NS
868+
else
869+
nsplat = N - sum(NI)
870+
if nsplat === 0
871+
ndi = NI
872+
nds = NS
873+
else
874+
splatmul = max(0, nsplat + 1)
875+
ndi = _map_splats(splatmul, SI, NI)
876+
nds = _map_splats(splatmul, SI, NS)
877+
end
878+
end
879+
if ndi === (1,) && N !== 1
880+
ns1 = getfield(nds, 1)
881+
new{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,),ns1}()
882+
else
883+
nds_cumsum = cumsum(nds)
884+
if sum(ndi) > N
885+
init_pdims = _accum_dims(cumsum(ndi), ndi)
886+
pdims = ntuple(nfields(init_pdims)) do i
887+
dim_i = getfield(init_pdims, i)
888+
if dim_i isa Tuple
889+
ntuple(length(dim_i)) do j
890+
dim_i_j = getfield(dim_i, j)
891+
dim_i_j > N ? 0 : dim_i_j
892+
end
893+
else
894+
dim_i > N ? 0 : dim_i
895+
end
896+
end
897+
new{N, pdims, _accum_dims(nds_cumsum, nds), last(nds_cumsum)}()
898+
else
899+
new{N,_accum_dims(cumsum(ndi), ndi), _accum_dims(nds_cumsum, nds), last(nds_cumsum)}()
900+
end
901+
end
848902
end
849-
end
850-
@inline function _indices_info(::Val{SI}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS,SI}
851-
nsplat = N - sum(NI)
852-
if nsplat === 0
853-
_indices_info(Val{nothing}(), IndicesInfo{N,NI,NS}())
854-
else
855-
splatmul = max(0, nsplat + 1)
856-
_indices_info(Val{nothing}(), IndicesInfo{N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)}())
903+
IndicesInfo{N}(@nospecialize(t::Tuple)) where {N} = IndicesInfo{N}(typeof(t))
904+
function IndicesInfo(@nospecialize(T::Type{<:Tuple}))
905+
ndi = map_tuple_type(ndims_index, T)
906+
nds = map_tuple_type(ndims_shape, T)
907+
ndi_sum = cumsum(ndi)
908+
nds_sum = cumsum(nds)
909+
nf = nfields(ndi_sum)
910+
pdims = _accum_dims(ndi_sum, ndi)
911+
cdims = _accum_dims(nds_sum, nds)
912+
new{getfield(ndi_sum, nf),pdims,cdims,getfield(nds_sum, nf)}()
857913
end
914+
IndicesInfo(@nospecialize t::Tuple) = IndicesInfo(typeof(t))
915+
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
916+
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
917+
end
918+
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
858919
end
859920
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})
860921
ntuple(length(dims)) do i
861922
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
862923
end
863924
end
864-
@inline function _replace_trailing(n::Int, dims::Tuple{Vararg{Any,N}}) where {N}
865-
ntuple(N) do i
866-
dim_i = getfield(dims, i)
867-
if dim_i isa Tuple
868-
ntuple(length(dim_i)) do j
869-
dim_i_j = getfield(dim_i, j)
870-
dim_i_j > n ? 0 : dim_i_j
871-
end
872-
else
873-
dim_i > n ? 0 : dim_i
874-
end
875-
end
876-
end
877925
@inline function _accum_dims(csdims::NTuple{N,Int}, nd::NTuple{N,Int}) where {N}
878926
ntuple(N) do i
879927
nd_i = getfield(nd, i)
@@ -887,6 +935,30 @@ end
887935
end
888936
end
889937

938+
_lower_info(::IndicesInfo{Np,pdims,cdims,Nc}) where {Np,pdims,cdims,Nc} = Np,pdims,cdims,Nc
939+
940+
ndims_index(@nospecialize(info::IndicesInfo)) = getfield(_lower_info(info), 1)
941+
ndims_shape(@nospecialize(info::IndicesInfo)) = getfield(_lower_info(info), 4)
942+
943+
"""
944+
parentdims(::IndicesInfo) -> Tuple
945+
946+
Returns the parent dimension mapping from `IndicesInfo`.
947+
948+
See also: [`IndicesInfo`](@ref), [`childdims`](@ref)
949+
"""
950+
parentdims(@nospecialize info::IndicesInfo) = getfield(_lower_info(info), 2)
951+
952+
"""
953+
childdims(::IndicesInfo) -> Tuple
954+
955+
Returns the child dimension mapping from `IndicesInfo`.
956+
957+
See also: [`IndicesInfo`](@ref), [`parentdims`](@ref)
958+
"""
959+
childdims(@nospecialize info::IndicesInfo) = getfield(_lower_info(info), 3)
960+
961+
890962
"""
891963
instances_do_not_alias(::Type{T}) -> Bool
892964

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -300,29 +300,29 @@ end
300300

301301
ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true
302302

303-
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
304-
IndicesInfo{1,(1,),((1,2),)}()
303+
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) isa
304+
IndicesInfo{1,(1,),((1,2),)}
305305

306-
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) == IndicesInfo{1, (1,), (1,)}()
306+
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) isa IndicesInfo{1, (1,), (1,)}
307307

308-
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) == IndicesInfo{2, (:,), (1,)}()
308+
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) isa IndicesInfo{2, (:,), (1,)}
309309

310-
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) == IndicesInfo{1, (1,), (1,)}()
310+
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) isa IndicesInfo{1, (1,), (1,)}
311311

312-
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) == IndicesInfo{2, ((1,2),), ((1, 2),)}()
312+
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) isa IndicesInfo{2, ((1,2),), ((1, 2),)}
313313

314-
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) ==
315-
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
314+
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) isa
315+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}
316316

317-
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) ==
318-
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}()
317+
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) isa
318+
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}
319319

320-
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) ==
321-
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}()
320+
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) isa
321+
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}
322322

323-
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) ==
324-
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}()
323+
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) isa
324+
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}
325325

326-
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) ==
327-
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}()
326+
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) isa
327+
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}
328328
end

src/ArrayInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
8-
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
9-
map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
8+
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
9+
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
1010
stride_preserving_index
1111

1212
# ArrayIndex subtypes and methods

src/dimensions.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11

22

33
_init_dimsmap(x) = _init_dimsmap(IndicesInfo(x))
4-
function _init_dimsmap(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
4+
function _init_dimsmap(@nospecialize info::IndicesInfo)
5+
pdims = parentdims(info)
6+
cdims = childdims(info)
57
ntuple(i -> static(getfield(pdims, i)), length(pdims)),
68
ntuple(i -> static(getfield(cdims, i)), length(pdims))
79
end
@@ -48,7 +50,9 @@ function _sub_axis_map(@nospecialize(T::Type{<:SubArray}), x::Tuple{StaticInt{in
4850
end
4951
end
5052

51-
function map_indices_info(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
53+
function map_indices_info(@nospecialize info::IndicesInfo)
54+
pdims = parentdims(info)
55+
cdims = childdims(info)
5256
ntuple(i -> (static(i), static(getfield(pdims, i)), static(getfield(cdims, i))), length(pdims))
5357
end
5458
function sub_dimnames_map(dnames::Tuple, imap::Tuple)
@@ -82,7 +86,9 @@ from_parent_dims(@nospecialize T::Type{<:MatAdjTrans}) = (StaticInt(2), StaticIn
8286
from_parent_dims(IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices)))
8387
end
8488
# TODO do I need to flatten_tuples here?
85-
function from_parent_dims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
89+
function from_parent_dims(@nospecialize(info::IndicesInfo))
90+
pdims = parentdims(info)
91+
cdims = childdims(info)
8692
ntuple(length(cdims)) do i
8793
pdim_i = getfield(pdims, i)
8894
cdim_i = static(getfield(cdims, i))

0 commit comments

Comments
 (0)