Skip to content

Commit dab38ab

Browse files
committed
Move ndims_index and is_splat_index to core and put docs back in
1 parent 6c9aaeb commit dab38ab

File tree

4 files changed

+69
-23
lines changed

4 files changed

+69
-23
lines changed

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,4 +511,25 @@ known_step(x) = known_step(typeof(x))
511511
known_step(::Type{T}) where {T} = parent_type(T) <: T ? nothing : known_step(parent_type(T))
512512
known_step(::Type{<:AbstractUnitRange}) = 1
513513

514+
"""
515+
is_splat_index(::Type{T}) -> Bool
516+
Returns `static(true)` if `T` is a type that splats across multiple dimensions.
517+
"""
518+
is_splat_index(@nospecialize(x)) = is_splat_index(typeof(x))
519+
is_splat_index(T::Type) = false
520+
521+
"""
522+
ndims_index(::Type{I}) -> Int
523+
524+
Returns the number of dimension that an instance of `I` maps to when indexing. For example,
525+
`CartesianIndex{3}` maps to 3 dimensions. If this method is not explicitly defined, then `1`
526+
is returned.
527+
"""
528+
ndims_index(@nospecialize(i)) = ndims_index(typeof(i))
529+
ndims_index(::Type{I}) where {I} = 1
530+
ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N
531+
ndims_index(::Type{<:AbstractArray{T}}) where {T} = ndims_index(T)
532+
ndims_index(::Type{<:AbstractArray{Bool,N}}) where {N} = N
533+
ndims_index(::Type{<:Base.LogicalIndex{<:Any,<:AbstractArray{Bool,N}}}) where {N} = N
534+
514535
end # module

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, merge_tuple_type,
8-
fast_scalar_indexing, parameterless_type, _is_reshaped
8+
fast_scalar_indexing, parameterless_type, _is_reshaped, ndims_index, is_splat_index
99

1010
# ArrayIndex subtypes and methods
1111
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex

src/dimensions.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
3636
end
3737
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
3838

39+
"""
40+
from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
41+
from_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
42+
43+
Returns the mapping from parent dimensions to child dimensions.
44+
"""
3945
from_parent_dims(x) = from_parent_dims(typeof(x))
4046
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
4147
from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),)
@@ -76,7 +82,6 @@ Compat.@constprop :aggressive function from_parent_dims(::Type{T}, dim::Int)::In
7682
throw_dim_error(T, dim)
7783
end
7884
end
79-
8085
function from_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
8186
if dim > ndims(T)
8287
return static(ndims(parent_type(T)) + dim - ndims(T))
@@ -87,6 +92,12 @@ function from_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
8792
end
8893
end
8994

95+
"""
96+
to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
97+
to_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
98+
99+
Returns the mapping from child dimensions to parent dimensions.
100+
"""
90101
to_parent_dims(x) = to_parent_dims(typeof(x))
91102
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
92103
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
@@ -135,10 +146,21 @@ function to_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
135146
end
136147
end
137148

138-
@inline function has_dimnames(x)
139-
static(known_dimnames(x) !== ntuple(Compat.Returns(:_), Val(ndims(x))))
140-
end
149+
"""
150+
has_dimnames(::Type{T}) -> Bool
141151
152+
Returns `true` if `x` has on or more named dimensions. If all dimensions correspond
153+
to `:_`, then `false` is returned.
154+
"""
155+
@inline has_dimnames(x) = static(known_dimnames(x) !== ntuple(Compat.Returns(:_), Val(ndims(x))))
156+
157+
"""
158+
known_dimnames(::Type{T}) -> Tuple{Vararg{Union{Symbol,Nothing}}}
159+
known_dimnames(::Type{T}, dim::Union{Int,StaticInt}) -> Union{Symbol,Nothing}
160+
161+
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
162+
have a name.
163+
"""
142164
@inline known_dimnames(x, dim::Integer) = _known_dimname(known_dimnames(x), canonicalize(dim))
143165
known_dimnames(x) = known_dimnames(typeof(x))
144166
known_dimnames(::Type{T}) where {T} = _known_dimnames(T, parent_type(T))
@@ -155,6 +177,13 @@ end
155177
end
156178
@inline _inbounds_known_dimname(x, dim) = @inbounds(_known_dimname(x, dim))
157179

180+
"""
181+
dimnames(x) -> Tuple{Vararg{Union{Symbol,StaticSymbol}}}
182+
dimnames(x, dim::Union{Int,StaticInt}) -> Union{Symbol,StaticSymbol}
183+
184+
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
185+
have a name.
186+
"""
158187
@inline dimnames(x, dim::Integer) = _dimname(dimnames(x), canonicalize(dim))
159188
@inline dimnames(x) = _dimnames(has_parent(x), x)
160189
@inline function _dimnames(::True, x)
@@ -169,6 +198,11 @@ _dimnames(::False, x) = ntuple(_->static(:_), Val(ndims(x)))
169198
end
170199
@inline _inbounds_dimname(x, dim) = @inbounds(_dimname(x, dim))
171200

201+
"""
202+
to_dims(x, dim) -> Union{Int,StaticInt}
203+
204+
This returns the dimension(s) of `x` corresponding to `dim`.
205+
"""
172206
to_dims(x, dim::Colon) = dim
173207
to_dims(x, dim::Integer) = canonicalize(dim)
174208
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)

src/indexing.jl

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
1-
is_splat_index(@nospecialize(x)) = is_splat_index(typeof(x))
2-
is_splat_index(::Type{T}) where {T} = static(false)
3-
_is_splat(::Type{I}, i::StaticInt) where {I} = is_splat_index(field_type(I, i))
41

5-
"""
6-
ndims_index(::Type{I}) -> StaticInt
7-
8-
Returns the number of dimension that an instance of `I` maps to when indexing. For example,
9-
`CartesianIndex{3}` maps to 3 dimensions. If this method is not explicitly defined, then `1`
10-
is returned.
2+
function _is_splat(::Type{I}, i::StaticInt) where {I}
3+
if dynamic(is_splat_index(field_type(I, i)))
4+
True()
5+
else
6+
False()
7+
end
8+
end
119

12-
"""
13-
ndims_index(@nospecialize(i)) = ndims_index(typeof(i))
14-
ndims_index(::Type{I}) where {I} = static(1)
15-
ndims_index(::Type{<:AbstractCartesianIndex{N}}) where {N} = static(N)
16-
ndims_index(::Type{<:AbstractArray{T}}) where {T} = ndims_index(T)
17-
ndims_index(::Type{<:AbstractArray{Bool,N}}) where {N} = static(N)
18-
ndims_index(::Type{<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}}) where {N} = static(N)
19-
_ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(field_type(I, i))
10+
_ndims_index(::Type{I}, i::StaticInt) where {I} = StaticInt(ndims_index(field_type(I, i)))
2011

2112
"""
2213
to_indices(A, I::Tuple) -> Tuple
@@ -238,7 +229,7 @@ indices calling [`to_axis`](@ref).
238229
end
239230
# drop this dimension
240231
to_axes(A, a::Tuple, i::Tuple{<:Integer,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
241-
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(ndims_index(I), A, a, i)
232+
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
242233
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
243234
return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...)
244235
end

0 commit comments

Comments
 (0)