@@ -9,7 +9,7 @@ using SuiteSparse
9
9
using Base: @assume_effects
10
10
else
11
11
macro assume_effects (_, ex)
12
- Base. @pure ex
12
+ :( Base. @pure $ (ex))
13
13
end
14
14
end
15
15
@@ -22,6 +22,72 @@ const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
22
22
const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
23
23
const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
24
24
25
+ """
26
+ ArrayInterfaceCore.map_tuple_type(f, T::Type{<:Tuple})
27
+
28
+ Returns tuple where each field corresponds to the field type of `T` modified by the function `f`.
29
+
30
+ # Examples
31
+
32
+ ```julia
33
+ julia> ArrayInterfaceCore.map_tuple_type(sqrt, Tuple{1,4,16})
34
+ (1.0, 2.0, 4.0)
35
+
36
+ ```
37
+ """
38
+ function map_tuple_type (f:: F , :: Type{T} ) where {F,T<: Tuple }
39
+ if @generated
40
+ t = Expr (:tuple )
41
+ for i in 1 : fieldcount (T)
42
+ push! (t. args, :(f ($ (fieldtype (T, i)))))
43
+ end
44
+ Expr (:block , Expr (:meta , :inline ), t)
45
+ else
46
+ Tuple (f (fieldtype (T, i)) for i in 1 : fieldcount (T))
47
+ end
48
+ end
49
+
50
+ """
51
+ ArrayInterfaceCore.flatten_tuples(t::Tuple) -> Tuple
52
+
53
+ Flattens any field of `t` that is a tuple. Only direct fields of `t` may be flattened.
54
+
55
+ # Examples
56
+
57
+ ```julia
58
+ julia> ArrayInterfaceCore.flatten_tuples((1, ()))
59
+ (1,)
60
+
61
+ julia> ArrayInterfaceCore.flatten_tuples((1, (2, 3)))
62
+ (1, 2, 3)
63
+
64
+ julia> ArrayInterfaceCore.flatten_tuples((1, (2, (3,))))
65
+ (1, 2, (3,))
66
+
67
+ ```
68
+ """
69
+ @inline function flatten_tuples (t:: Tuple )
70
+ if @generated
71
+ texpr = Expr (:tuple )
72
+ for i in 1 : fieldcount (t)
73
+ p = fieldtype (t, i)
74
+ if p <: Tuple
75
+ for j in 1 : fieldcount (p)
76
+ push! (texpr. args, :(@inbounds (getfield (getfield (t, $ i), $ j))))
77
+ end
78
+ else
79
+ push! (texpr. args, :(@inbounds (getfield (t, $ i))))
80
+ end
81
+ end
82
+ Expr (:block , Expr (:meta , :inline ), texpr)
83
+ else
84
+ _flatten (t)
85
+ end
86
+ end
87
+ _flatten (:: Tuple{} ) = ()
88
+ @inline _flatten (t:: Tuple{Any,Vararg{Any}} ) = (getfield (t, 1 ), _flatten (Base. tail (t))... )
89
+ @inline _flatten (t:: Tuple{Tuple,Vararg{Any}} ) = (getfield (t, 1 )... , _flatten (Base. tail (t))... )
90
+
25
91
"""
26
92
parent_type(::Type{T}) -> Type
27
93
@@ -591,32 +657,100 @@ indexing with an instance of `I`.
591
657
"""
592
658
ndims_shape (T:: DataType ) = ndims_index (T)
593
659
ndims_shape (:: Type{Colon} ) = 1
594
- ndims_shape (T:: Type{<:Base.AbstractCartesianIndex{N}} ) where {N} = ntuple (zero, Val {N} () )
595
- ndims_shape (@nospecialize T:: Type{<:CartesianIndices} ) = ntuple (one, Val {ndims(T)} ())
596
- ndims_shape (@nospecialize T:: Type{<:Number} ) = 0
660
+ ndims_shape (@nospecialize T:: Type{<:CartesianIndices} ) = ndims (T )
661
+ ndims_shape (@nospecialize T:: Type{<:Union{Number,Base.AbstractCartesianIndex}} ) = 0
662
+ ndims_shape (@nospecialize T:: Type{<:AbstractArray{Bool}} ) = 1
597
663
ndims_shape (@nospecialize T:: Type{<:AbstractArray} ) = ndims (T)
598
664
ndims_shape (x) = ndims_shape (typeof (x))
599
665
666
+ @assume_effects :total function _find_first_true (isi:: Tuple{Vararg{Bool,N}} ) where {N}
667
+ for i in 1 : N
668
+ getfield (isi, i) && return i
669
+ end
670
+ return nothing
671
+ end
672
+
600
673
"""
601
- IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS }()
674
+ IndicesInfo{N} (T::Type{<:Tuple}) -> IndicesInfo{N, NI,NS}()
602
675
603
676
Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
604
677
`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
605
678
[`is_splat_index`](@ref) (respectively) for each field of `T`.
679
+
680
+ # Examples
681
+
682
+ ```julia
683
+ julia> using ArrayInterfaceCore: IndicesInfo
684
+
685
+ julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
686
+ IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
687
+
688
+ ```
606
689
"""
607
- struct IndicesInfo{NI,NS,IS} end
608
- IndicesInfo (@nospecialize x:: Tuple ) = IndicesInfo (typeof (x))
609
- @generated function IndicesInfo (:: Type{T} ) where {T<: Tuple }
610
- NI = Expr (:tuple )
611
- NS = Expr (:tuple )
612
- IS = Expr (:tuple )
613
- for i in 1 : fieldcount (T)
614
- T_i = fieldtype (T, i)
615
- push! (NI. args, :(ndims_index ($ (T_i))))
616
- push! (NS. args, :(ndims_shape ($ (T_i))))
617
- push! (IS. args, :(is_splat_index ($ (T_i))))
690
+ struct IndicesInfo{N,NI,NS} end
691
+ IndicesInfo (x:: SubArray ) = IndicesInfo {ndims(parent(x))} (typeof (x. indices))
692
+ @inline function IndicesInfo (@nospecialize T:: Type{<:SubArray} )
693
+ IndicesInfo {ndims(parent_type(T))} (fieldtype (T, :indices ))
694
+ end
695
+ function IndicesInfo {N} (@nospecialize (T:: Type{<:Tuple} )) where {N}
696
+ _indices_info (
697
+ Val {_find_first_true(map_tuple_type(is_splat_index, T))} (),
698
+ IndicesInfo {N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)} ()
699
+ )
700
+ end
701
+ function _indices_info (:: Val{nothing} , :: IndicesInfo{1,(1,),NS} ) where {NS}
702
+ ns1 = getfield (NS, 1 )
703
+ IndicesInfo {1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
704
+ end
705
+ function _indices_info (:: Val{nothing} , :: IndicesInfo{N,(1,),NS} ) where {N,NS}
706
+ ns1 = getfield (NS, 1 )
707
+ IndicesInfo {N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)} ()
708
+ end
709
+ @inline function _indices_info (:: Val{nothing} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS}
710
+ if sum (NI) > N
711
+ IndicesInfo {N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)} ()
712
+ else
713
+ IndicesInfo {N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)} ()
714
+ end
715
+ end
716
+ @inline function _indices_info (:: Val{SI} , :: IndicesInfo{N,NI,NS} ) where {N,NI,NS,SI}
717
+ nsplat = N - sum (NI)
718
+ if nsplat === 0
719
+ _indices_info (Val {nothing} (), IndicesInfo {N,NI,NS} ())
720
+ else
721
+ splatmul = max (0 , nsplat + 1 )
722
+ _indices_info (Val {nothing} (), IndicesInfo {N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)} ())
723
+ end
724
+ end
725
+ @inline function _map_splats (nsplat:: Int , splat_index:: Int , dims:: Tuple{Vararg{Int}} )
726
+ ntuple (length (dims)) do i
727
+ i === splat_index ? (nsplat * getfield (dims, i)) : getfield (dims, i)
728
+ end
729
+ end
730
+ @inline function _replace_trailing (n:: Int , dims:: Tuple{Vararg{Any,N}} ) where {N}
731
+ ntuple (N) do i
732
+ dim_i = getfield (dims, i)
733
+ if dim_i isa Tuple
734
+ ntuple (length (dim_i)) do j
735
+ dim_i_j = getfield (dim_i, j)
736
+ dim_i_j > n ? 0 : dim_i_j
737
+ end
738
+ else
739
+ dim_i > n ? 0 : dim_i
740
+ end
741
+ end
742
+ end
743
+ @inline function _accum_dims (csdims:: NTuple{N,Int} , nd:: NTuple{N,Int} ) where {N}
744
+ ntuple (N) do i
745
+ nd_i = getfield (nd, i)
746
+ if nd_i === 0
747
+ 0
748
+ elseif nd_i === 1
749
+ getfield (csdims, i)
750
+ else
751
+ ntuple (Base. Fix1 (+ , getfield (csdims, i) - nd_i), nd_i)
752
+ end
618
753
end
619
- Expr (:block , Expr (:meta , :inline ), :(IndicesInfo {$(NI),$(NS),$(IS)} ()))
620
754
end
621
755
622
756
"""
0 commit comments