Skip to content

Commit 702e092

Browse files
authored
New dimension mapping (#327)
* replacing old to_parent_dims with more robust interface * Replace generated SubArray methods * Versions without `@assume_effects` weren't getting the effects of `@pure` * Add in select implementations of to/from_parent_dims with warning * Integrate dimension into `IndicesInfo`
1 parent 1c6c725 commit 702e092

File tree

12 files changed

+459
-500
lines changed

12 files changed

+459
-500
lines changed

lib/ArrayInterfaceCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterfaceCore"
22
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
3-
version = "0.1.13"
3+
version = "0.1.14"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 151 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using SuiteSparse
99
using Base: @assume_effects
1010
else
1111
macro assume_effects(_, ex)
12-
Base.@pure ex
12+
:(Base.@pure $(ex))
1313
end
1414
end
1515

@@ -22,6 +22,72 @@ const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
2222
const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
2323
const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
2424

25+
"""
26+
ArrayInterfaceCore.map_tuple_type(f, T::Type{<:Tuple})
27+
28+
Returns tuple where each field corresponds to the field type of `T` modified by the function `f`.
29+
30+
# Examples
31+
32+
```julia
33+
julia> ArrayInterfaceCore.map_tuple_type(sqrt, Tuple{1,4,16})
34+
(1.0, 2.0, 4.0)
35+
36+
```
37+
"""
38+
function map_tuple_type(f::F, ::Type{T}) where {F,T<:Tuple}
39+
if @generated
40+
t = Expr(:tuple)
41+
for i in 1:fieldcount(T)
42+
push!(t.args, :(f($(fieldtype(T, i)))))
43+
end
44+
Expr(:block, Expr(:meta, :inline), t)
45+
else
46+
Tuple(f(fieldtype(T, i)) for i in 1:fieldcount(T))
47+
end
48+
end
49+
50+
"""
51+
ArrayInterfaceCore.flatten_tuples(t::Tuple) -> Tuple
52+
53+
Flattens any field of `t` that is a tuple. Only direct fields of `t` may be flattened.
54+
55+
# Examples
56+
57+
```julia
58+
julia> ArrayInterfaceCore.flatten_tuples((1, ()))
59+
(1,)
60+
61+
julia> ArrayInterfaceCore.flatten_tuples((1, (2, 3)))
62+
(1, 2, 3)
63+
64+
julia> ArrayInterfaceCore.flatten_tuples((1, (2, (3,))))
65+
(1, 2, (3,))
66+
67+
```
68+
"""
69+
@inline function flatten_tuples(t::Tuple)
70+
if @generated
71+
texpr = Expr(:tuple)
72+
for i in 1:fieldcount(t)
73+
p = fieldtype(t, i)
74+
if p <: Tuple
75+
for j in 1:fieldcount(p)
76+
push!(texpr.args, :(@inbounds(getfield(getfield(t, $i), $j))))
77+
end
78+
else
79+
push!(texpr.args, :(@inbounds(getfield(t, $i))))
80+
end
81+
end
82+
Expr(:block, Expr(:meta, :inline), texpr)
83+
else
84+
_flatten(t)
85+
end
86+
end
87+
_flatten(::Tuple{}) = ()
88+
@inline _flatten(t::Tuple{Any,Vararg{Any}}) = (getfield(t, 1), _flatten(Base.tail(t))...)
89+
@inline _flatten(t::Tuple{Tuple,Vararg{Any}}) = (getfield(t, 1)..., _flatten(Base.tail(t))...)
90+
2591
"""
2692
parent_type(::Type{T}) -> Type
2793
@@ -591,32 +657,100 @@ indexing with an instance of `I`.
591657
"""
592658
ndims_shape(T::DataType) = ndims_index(T)
593659
ndims_shape(::Type{Colon}) = 1
594-
ndims_shape(T::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = ntuple(zero, Val{N}())
595-
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ntuple(one, Val{ndims(T)}())
596-
ndims_shape(@nospecialize T::Type{<:Number}) = 0
660+
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
661+
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0
662+
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
597663
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
598664
ndims_shape(x) = ndims_shape(typeof(x))
599665

666+
@assume_effects :total function _find_first_true(isi::Tuple{Vararg{Bool,N}}) where {N}
667+
for i in 1:N
668+
getfield(isi, i) && return i
669+
end
670+
return nothing
671+
end
672+
600673
"""
601-
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS}()
674+
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,NI,NS}()
602675
603676
Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
604677
`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
605678
[`is_splat_index`](@ref) (respectively) for each field of `T`.
679+
680+
# Examples
681+
682+
```julia
683+
julia> using ArrayInterfaceCore: IndicesInfo
684+
685+
julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
686+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
687+
688+
```
606689
"""
607-
struct IndicesInfo{NI,NS,IS} end
608-
IndicesInfo(@nospecialize x::Tuple) = IndicesInfo(typeof(x))
609-
@generated function IndicesInfo(::Type{T}) where {T<:Tuple}
610-
NI = Expr(:tuple)
611-
NS = Expr(:tuple)
612-
IS = Expr(:tuple)
613-
for i in 1:fieldcount(T)
614-
T_i = fieldtype(T, i)
615-
push!(NI.args, :(ndims_index($(T_i))))
616-
push!(NS.args, :(ndims_shape($(T_i))))
617-
push!(IS.args, :(is_splat_index($(T_i))))
690+
struct IndicesInfo{N,NI,NS} end
691+
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
692+
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
693+
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
694+
end
695+
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
696+
_indices_info(
697+
Val{_find_first_true(map_tuple_type(is_splat_index, T))}(),
698+
IndicesInfo{N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)}()
699+
)
700+
end
701+
function _indices_info(::Val{nothing}, ::IndicesInfo{1,(1,),NS}) where {NS}
702+
ns1 = getfield(NS, 1)
703+
IndicesInfo{1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
704+
end
705+
function _indices_info(::Val{nothing}, ::IndicesInfo{N,(1,),NS}) where {N,NS}
706+
ns1 = getfield(NS, 1)
707+
IndicesInfo{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
708+
end
709+
@inline function _indices_info(::Val{nothing}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS}
710+
if sum(NI) > N
711+
IndicesInfo{N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)}()
712+
else
713+
IndicesInfo{N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)}()
714+
end
715+
end
716+
@inline function _indices_info(::Val{SI}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS,SI}
717+
nsplat = N - sum(NI)
718+
if nsplat === 0
719+
_indices_info(Val{nothing}(), IndicesInfo{N,NI,NS}())
720+
else
721+
splatmul = max(0, nsplat + 1)
722+
_indices_info(Val{nothing}(), IndicesInfo{N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)}())
723+
end
724+
end
725+
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})
726+
ntuple(length(dims)) do i
727+
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
728+
end
729+
end
730+
@inline function _replace_trailing(n::Int, dims::Tuple{Vararg{Any,N}}) where {N}
731+
ntuple(N) do i
732+
dim_i = getfield(dims, i)
733+
if dim_i isa Tuple
734+
ntuple(length(dim_i)) do j
735+
dim_i_j = getfield(dim_i, j)
736+
dim_i_j > n ? 0 : dim_i_j
737+
end
738+
else
739+
dim_i > n ? 0 : dim_i
740+
end
741+
end
742+
end
743+
@inline function _accum_dims(csdims::NTuple{N,Int}, nd::NTuple{N,Int}) where {N}
744+
ntuple(N) do i
745+
nd_i = getfield(nd, i)
746+
if nd_i === 0
747+
0
748+
elseif nd_i === 1
749+
getfield(csdims, i)
750+
else
751+
ntuple(Base.Fix1(+, getfield(csdims, i) - nd_i), nd_i)
752+
end
618753
end
619-
Expr(:block, Expr(:meta, :inline), :(IndicesInfo{$(NI),$(NS),$(IS)}()))
620754
end
621755

622756
"""

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ArrayInterfaceCore
22
using ArrayInterfaceCore: zeromatrix
33
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
4-
parent_type, zeromatrix
4+
parent_type, zeromatrix, IndicesInfo
55
using Base: setindex
66
using LinearAlgebra
77
using Random
@@ -271,8 +271,8 @@ end
271271
@testset "ndims_shape" begin
272272
@test @inferred(ArrayInterfaceCore.ndims_shape(1)) === 0
273273
@test @inferred(ArrayInterfaceCore.ndims_shape(:)) === 1
274-
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndex(1, 2))) === (0, 0)
275-
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndices((2,2)))) === (1, 1)
274+
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndex(1, 2))) === 0
275+
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndices((2,2)))) === 2
276276
@test @inferred(ArrayInterfaceCore.ndims_shape([1 1])) === 2
277277
end
278278

@@ -293,3 +293,36 @@ end
293293
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(fill(rand(4,4),4,4)', 2:3, 1:2)))
294294
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(rand(4,4)', StepRangeLen(1,0,5), 1:2)))
295295
end
296+
297+
@testset "IndicesInfo" begin
298+
299+
struct SplatFirst end
300+
301+
ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true
302+
303+
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
304+
IndicesInfo{1,(1,),((1,2),)}()
305+
306+
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) == IndicesInfo{1, (1,), (1,)}()
307+
308+
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) == IndicesInfo{2, (:,), (1,)}()
309+
310+
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) == IndicesInfo{1, (1,), (1,)}()
311+
312+
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) == IndicesInfo{2, ((1,2),), ((1, 2),)}()
313+
314+
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) ==
315+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
316+
317+
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) ==
318+
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}()
319+
320+
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) ==
321+
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}()
322+
323+
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) ==
324+
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}()
325+
326+
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) ==
327+
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}()
328+
end

src/ArrayInterface.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
8-
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo
8+
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
9+
map_tuple_type, flatten_tuples, GetIndex
910

1011
# ArrayIndex subtypes and methods
1112
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
@@ -34,8 +35,6 @@ using LinearAlgebra
3435

3536
import Compat
3637

37-
n_of_x(::StaticInt{N}, x::X) where {N,X} = ntuple(Compat.Returns(x), Val{N}())
38-
3938
_add1(@nospecialize x) = x + oneunit(x)
4039
_sub1(@nospecialize x) = x - oneunit(x)
4140

0 commit comments

Comments
 (0)