Skip to content

Commit 4b8c553

Browse files
authored
Cleaning up indexing and dimension mapping code. (#181)
* Clean up index-dims mapping * Clean up dimension/indexing related code/docs * Realized that a lot of things I had documented internally are now in the online docs, likely making them appear more stable and public than they are right now. * `argdims` was ambiguously documented and not used elsewhere so I split it into two separate/more useful methos: ndims_index and ndims_subset * Deleted/consolidated a bunch of indexing code * Lispy approach to ndims on tuple of inds * Drop support for contectual index mapping for now. * Clean up and more tests * removed ndims_subset for now. * added canonicalize tests * version bump
1 parent 8495524 commit 4b8c553

File tree

6 files changed

+136
-190
lines changed

6 files changed

+136
-190
lines changed

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,9 +662,9 @@ function _is_lazy_conjugate(::Type{T}, isconj) where {T <: Adjoint}
662662
end
663663

664664
include("ranges.jl")
665-
include("dimensions.jl")
666665
include("axes.jl")
667666
include("size.jl")
667+
include("dimensions.jl")
668668
include("indexing.jl")
669669
include("stridelayout.jl")
670670
include("broadcast.jl")

src/dimensions.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,22 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
2222
end
2323
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
2424

25+
#=
26+
ndims_index(::Type{I})::StaticInt
27+
28+
The number of dimensions an instance of `I` maps to when indexing an instance of `A`.
29+
=#
30+
ndims_index(i) = ndims_index(typeof(i))
31+
ndims_index(::Type{I}) where {I} = static(1)
32+
ndims_index(::Type{I}) where {N,I<:AbstractCartesianIndex{N}} = static(N)
33+
ndims_index(::Type{I}) where {I<:AbstractArray} = ndims_index(eltype(I))
34+
ndims_index(::Type{I}) where {I<:AbstractArray{Bool}} = static(ndims(I))
35+
ndims_index(::Type{I}) where {N,I<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N)
36+
_ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(_get_tuple(I, i))
37+
ndims_index(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_ndims_index, nstatic(Val(N)), I)
38+
2539
"""
26-
from_parent_dims(::Type{T}) -> Tuple
40+
from_parent_dims(::Type{T})::Tuple{Vararg{Union{Int,StaticInt}}}
2741
2842
Returns the mapping from parent dimensions to child dimensions.
2943
"""
@@ -37,17 +51,16 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A
3751
dim_i = 1
3852
for i in 1:ndims(A)
3953
p = I.parameters[i]
40-
if argdims(A, p) > 0
54+
if p <: Integer
55+
push!(out.args, :(StaticInt(0)))
56+
else
4157
push!(out.args, :(StaticInt($dim_i)))
4258
dim_i += 1
43-
else
44-
push!(out.args, :(StaticInt(0)))
4559
end
4660
end
4761
out
4862
end
4963
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))
50-
5164
function from_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
5265
if !_is_reshaped(R) || sizeof(S) === sizeof(T)
5366
return nstatic(Val(ndims(A)))
@@ -59,7 +72,7 @@ function from_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}
5972
end
6073

6174
"""
62-
from_parent_dims(::Type{T}, dim) -> Integer
75+
from_parent_dims(::Type{T}, dim)::Union{Int,StaticInt}
6376
6477
Returns the mapping from child dimensions to parent dimensions.
6578
"""
@@ -85,7 +98,7 @@ function from_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
8598
end
8699

87100
"""
88-
to_parent_dims(::Type{T}) -> Tuple
101+
to_parent_dims(::Type{T})::Tuple{Vararg{Union{Int,StaticInt}}}
89102
90103
Returns the mapping from child dimensions to parent dimensions.
91104
"""
@@ -98,7 +111,7 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
98111
out = Expr(:tuple)
99112
n = 1
100113
for p in I.parameters
101-
if argdims(A, p) > 0
114+
if !(p <: Integer)
102115
push!(out.args, :(StaticInt($n)))
103116
end
104117
n += 1
@@ -117,7 +130,7 @@ function to_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
117130
end
118131

119132
"""
120-
to_parent_dims(::Type{T}, dim) -> Integer
133+
to_parent_dims(::Type{T}, dim)::Union{Int,StaticInt}
121134
122135
Returns the mapping from child dimensions to parent dimensions.
123136
"""
@@ -143,7 +156,7 @@ function to_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
143156
end
144157

145158
"""
146-
has_dimnames(::Type{T}) -> Bool
159+
has_dimnames(::Type{T})::Bool
147160
148161
Returns `true` if `x` has names for each dimension.
149162
"""
@@ -160,8 +173,8 @@ end
160173
const SUnderscore = StaticSymbol(:_)
161174

162175
"""
163-
dimnames(::Type{T}) -> Tuple{Vararg{StaticSymbol}}
164-
dimnames(::Type{T}, dim) -> StaticSymbol
176+
dimnames(::Type{T})::Tuple{Vararg{StaticSymbol}}
177+
dimnames(::Type{T}, dim)::StaticSymbol
165178
166179
Return the names of the dimensions for `x`.
167180
"""
@@ -191,7 +204,7 @@ function dimnames(::Type{T}) where {T<:SubArray}
191204
end
192205

193206
"""
194-
to_dims(::Type{T}, dim) -> Integer
207+
to_dims(::Type{T}, dim)::Union{Int,StaticInt}
195208
196209
This returns the dimension(s) of `x` corresponding to `d`.
197210
"""

0 commit comments

Comments
 (0)