@@ -806,74 +806,122 @@ ndims_shape(x) = ndims_shape(typeof(x))
806
806
end
807
807
808
808
"""
809
+ IndicesInfo{N}(inds::Tuple) -> IndicesInfo{N}(typeof(inds))
809
810
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,pdims,cdims}()
811
+ IndicesInfo(inds::Tuple) -> IndicesInfo(typeof(inds))
812
+ IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{maximum(pdims),pdims,cdims}()
810
813
811
- Provides basic trait information for each index type in in the tuple `T`. `pdims` and
812
- `cdims` are dimension mappings to the parent and child dimensions respectively.
814
+
815
+ Maps a tuple of indices to `N` dimensions. The resulting `pdims` is a tuple where each
816
+ field in `inds` (or field type in `T`) corresponds to the parent dimensions accessed.
817
+ `cdims` similarly maps indices to the resulting child array produced after indexing with
818
+ `inds`. If `N` is not provided then it is assumed that all indices are represented by parent
819
+ dimensions and there are no trailing dimensions accessed. These may be accessed by through
820
+ `parentdims(info::IndicesInfo)` and `childdims(info::IndicesInfo)`. If `N` is not provided,
821
+ it is assumed that no indices are accessing trailing dimensions (which are represented as
822
+ `0` in `parentdims(info)[index_position]`).
823
+
824
+ The the fields and types of `IndicesInfo` should not be accessed directly.
825
+ Instead [`parentdims`](@ref), [`childdims`](@ref), [`ndims_index`](@ref), and
826
+ [`ndims_shape`](@ref) should be used to extract relevant information.
813
827
814
828
# Examples
815
829
816
830
```julia
817
- julia> using ArrayInterfaceCore: IndicesInfo
831
+ julia> using ArrayInterfaceCore: IndicesInfo, parentdims, childdims, ndims_index, ndims_shape
832
+
833
+ julia> info = IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
834
+
835
+ julia> parentdims(info) # the last two indices access trailing dimensions
836
+ (1, (2, 3), 4, 5, 0, 0)
837
+
838
+ julia> childdims(info)
839
+ (1, 2, 0, (3, 4), 5, 0)
840
+
841
+ julia> childdims(info)[3] # index 3 accesses a parent dimension but is dropped in the child array
842
+ 0
818
843
819
- julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
820
- IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
844
+ julia> ndims_index(info)
845
+ 5
846
+
847
+ julia> ndims_shape(info)
848
+ 5
849
+
850
+ julia> info = IndicesInfo(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
851
+
852
+ julia> parentdims(info) # assumed no trailing dimensions
853
+ (1, (2, 3), 4, 5, 6, 7)
854
+
855
+ julia> ndims_index(info) # assumed no trailing dimensions
856
+ 7
821
857
822
858
```
823
859
"""
824
- struct IndicesInfo{N,NI,NS} end
825
- IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
826
- @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
827
- IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
828
- end
829
- function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
830
- _indices_info (
831
- Val {_find_first_true(map_tuple_type(is_splat_index, T))} (),
832
- IndicesInfo {N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)} ()
833
- )
834
- end
835
- function _indices_info (:: Val{nothing} , :: IndicesInfo{1,(1,),NS} ) where {NS}
836
- ns1 = getfield (NS, 1 )
837
- IndicesInfo {1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
838
- end
839
- function _indices_info (:: Val{nothing} , :: IndicesInfo{N,(1,),NS} ) where {N,NS}
840
- ns1 = getfield (NS, 1 )
841
- IndicesInfo {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
842
- end
843
- @inline function _indices_info (:: Val{nothing} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS}
844
- if sum (NI) > N
845
- IndicesInfo {N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)} ()
846
- else
847
- IndicesInfo {N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)} ()
860
+ struct IndicesInfo{Np,pdims,cdims,Nc}
861
+ function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
862
+ SI = _find_first_true (map_tuple_type (is_splat_index, T))
863
+ NI = map_tuple_type (ndims_index, T)
864
+ NS = map_tuple_type (ndims_shape, T)
865
+ if SI === nothing
866
+ ndi = NI
867
+ nds = NS
868
+ else
869
+ nsplat = N - sum (NI)
870
+ if nsplat === 0
871
+ ndi = NI
872
+ nds = NS
873
+ else
874
+ splatmul = max (0 , nsplat + 1 )
875
+ ndi = _map_splats (splatmul, SI, NI)
876
+ nds = _map_splats (splatmul, SI, NS)
877
+ end
878
+ end
879
+ if ndi === (1 ,) && N != = 1
880
+ ns1 = getfield (nds, 1 )
881
+ new {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,),ns1} ()
882
+ else
883
+ nds_cumsum = cumsum (nds)
884
+ if sum (ndi) > N
885
+ init_pdims = _accum_dims (cumsum (ndi), ndi)
886
+ pdims = ntuple (nfields (init_pdims)) do i
887
+ dim_i = getfield (init_pdims, i)
888
+ if dim_i isa Tuple
889
+ ntuple (length (dim_i)) do j
890
+ dim_i_j = getfield (dim_i, j)
891
+ dim_i_j > N ? 0 : dim_i_j
892
+ end
893
+ else
894
+ dim_i > N ? 0 : dim_i
895
+ end
896
+ end
897
+ new {N, pdims, _accum_dims(nds_cumsum, nds), last(nds_cumsum)} ()
898
+ else
899
+ new {N,_accum_dims(cumsum(ndi), ndi), _accum_dims(nds_cumsum, nds), last(nds_cumsum)} ()
900
+ end
901
+ end
848
902
end
849
- end
850
- @inline function _indices_info (:: Val{SI} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS,SI}
851
- nsplat = N - sum (NI)
852
- if nsplat === 0
853
- _indices_info (Val {nothing} (), IndicesInfo {N,NI,NS} ())
854
- else
855
- splatmul = max (0 , nsplat + 1 )
856
- _indices_info (Val {nothing} (), IndicesInfo {N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)} ())
903
+ IndicesInfo {N} (@nospecialize (t:: Tuple )) where {N} = IndicesInfo {N} (typeof (t))
904
+ function IndicesInfo (@nospecialize (T:: Type{<:Tuple} ))
905
+ ndi = map_tuple_type (ndims_index, T)
906
+ nds = map_tuple_type (ndims_shape, T)
907
+ ndi_sum = cumsum (ndi)
908
+ nds_sum = cumsum (nds)
909
+ nf = nfields (ndi_sum)
910
+ pdims = _accum_dims (ndi_sum, ndi)
911
+ cdims = _accum_dims (nds_sum, nds)
912
+ new {getfield(ndi_sum, nf),pdims,cdims,getfield(nds_sum, nf)} ()
857
913
end
914
+ IndicesInfo (@nospecialize t:: Tuple ) = IndicesInfo (typeof (t))
915
+ @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
916
+ IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
917
+ end
918
+ IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
858
919
end
859
920
@inline function _map_splats (nsplat:: Int , splat_index:: Int , dims:: Tuple{Vararg{Int}} )
860
921
ntuple (length (dims)) do i
861
922
i === splat_index ? (nsplat * getfield (dims, i)) : getfield (dims, i)
862
923
end
863
924
end
864
- @inline function _replace_trailing (n:: Int , dims:: Tuple{Vararg{Any,N}} ) where {N}
865
- ntuple (N) do i
866
- dim_i = getfield (dims, i)
867
- if dim_i isa Tuple
868
- ntuple (length (dim_i)) do j
869
- dim_i_j = getfield (dim_i, j)
870
- dim_i_j > n ? 0 : dim_i_j
871
- end
872
- else
873
- dim_i > n ? 0 : dim_i
874
- end
875
- end
876
- end
877
925
@inline function _accum_dims (csdims:: NTuple{N,Int} , nd:: NTuple{N,Int} ) where {N}
878
926
ntuple (N) do i
879
927
nd_i = getfield (nd, i)
887
935
end
888
936
end
889
937
938
+ _lower_info (:: IndicesInfo{Np,pdims,cdims,Nc} ) where {Np,pdims,cdims,Nc} = Np,pdims,cdims,Nc
939
+
940
+ ndims_index (@nospecialize (info:: IndicesInfo )) = getfield (_lower_info (info), 1 )
941
+ ndims_shape (@nospecialize (info:: IndicesInfo )) = getfield (_lower_info (info), 4 )
942
+
943
+ """
944
+ parentdims(::IndicesInfo) -> Tuple
945
+
946
+ Returns the parent dimension mapping from `IndicesInfo`.
947
+
948
+ See also: [`IndicesInfo`](@ref), [`childdims`](@ref)
949
+ """
950
+ parentdims (@nospecialize info:: IndicesInfo ) = getfield (_lower_info (info), 2 )
951
+
952
+ """
953
+ childdims(::IndicesInfo) -> Tuple
954
+
955
+ Returns the child dimension mapping from `IndicesInfo`.
956
+
957
+ See also: [`IndicesInfo`](@ref), [`parentdims`](@ref)
958
+ """
959
+ childdims (@nospecialize info:: IndicesInfo ) = getfield (_lower_info (info), 3 )
960
+
961
+
890
962
"""
891
963
instances_do_not_alias(::Type{T}) -> Bool
892
964
0 commit comments