Skip to content

Commit 4d205d7

Browse files
authored
Merge pull request #83 from Tokazama/master
Formalize interface for keyword arguments in indexing
2 parents 7ebbfc7 + 09b17fa commit 4d205d7

File tree

7 files changed

+356
-40
lines changed

7 files changed

+356
-40
lines changed

src/ArrayInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ parameterless_type(x) = parameterless_type(typeof(x))
1111
parameterless_type(x::Type) = __parameterless_type(x)
1212

1313
"""
14-
parent_type(x)
14+
parent_type(::Type{T})
1515
1616
Returns the parent array that `x` wraps.
1717
"""
@@ -889,6 +889,7 @@ end
889889

890890
include("static.jl")
891891
include("ranges.jl")
892+
include("dimensions.jl")
892893
include("indexing.jl")
893894
include("stridelayout.jl")
894895

src/dimensions.jl

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
2+
"""
3+
has_dimnames(::Type{T}) -> Bool
4+
5+
Returns `true` if `x` has names for each dimension.
6+
"""
7+
@inline has_dimnames(x) = has_dimnames(typeof(x))
8+
function has_dimnames(::Type{T}) where {T}
9+
if parent_type(T) <: T
10+
return false
11+
else
12+
return has_dimnames(parent_type(T))
13+
end
14+
end
15+
16+
"""
17+
dimnames(::Type{T}) -> Tuple{Vararg{Symbol}}
18+
dimnames(::Type{T}, d) -> Symbol
19+
20+
Return the names of the dimensions for `x`.
21+
"""
22+
@inline dimnames(x) = dimnames(typeof(x))
23+
@inline dimnames(x, i::Integer) = dimnames(typeof(x), i)
24+
@inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d))
25+
@inline function dimnames(::Type{T}) where {T}
26+
if parent_type(T) <: T
27+
return ntuple(i -> :_, Val(ndims(T)))
28+
else
29+
return dimnames(parent_type(T))
30+
end
31+
end
32+
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
33+
return _transpose_dimnames(dimnames(parent_type(T)))
34+
end
35+
_transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
36+
_transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
37+
38+
@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
39+
return map(i -> dimnames(parent_type(T), i), I)
40+
end
41+
function dimnames(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
42+
return _sub_array_dimnames(Val(dimnames(P)), Val(argdims(P, I)))
43+
end
44+
@generated function _sub_array_dimnames(::Val{L}, ::Val{I}) where {L,I}
45+
e = Expr(:tuple)
46+
nl = length(L)
47+
for i in 1:length(I)
48+
if I[i] > 0
49+
if nl < i
50+
push!(e.args, QuoteNode(:_))
51+
else
52+
push!(e.args, QuoteNode(L[i]))
53+
end
54+
end
55+
end
56+
return e
57+
end
58+
59+
"""
60+
to_dims(x[, d])
61+
62+
This returns the dimension(s) of `x` corresponding to `d`.
63+
"""
64+
to_dims(x, d::Integer) = Int(d)
65+
to_dims(x, d::Colon) = d # `:` is the default for most methods that take `dims`
66+
@inline to_dims(x, d::Tuple) = map(i -> to_dims(x, i), d)
67+
@inline function to_dims(x, d::Symbol)::Int
68+
i = _sym_to_dim(dimnames(x), d)
69+
if i === 0
70+
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(dimnames(x)))"))
71+
end
72+
return i
73+
end
74+
Base.@pure function _sym_to_dim(x::Tuple{Vararg{Symbol,N}}, sym::Symbol) where {N}
75+
for i in 1:N
76+
getfield(x, i) === sym && return i
77+
end
78+
return 0
79+
end
80+
81+
"""
82+
tuple_issubset
83+
84+
A version of `issubset` sepecifically for `Tuple`s of `Symbol`s, that is `@pure`.
85+
This helps it get optimised out of existance. It is less of an abuse of `@pure` than
86+
most of the stuff for making `NamedTuples` work.
87+
"""
88+
Base.@pure function tuple_issubset(
89+
lhs::Tuple{Vararg{Symbol,N}}, rhs::Tuple{Vararg{Symbol,M}},
90+
) where {N,M}
91+
N <= M || return false
92+
for a in lhs
93+
found = false
94+
for b in rhs
95+
found |= a === b
96+
end
97+
found || return false
98+
end
99+
return true
100+
end
101+
102+
"""
103+
order_named_inds(Val(names); kwargs...)
104+
order_named_inds(Val(names), namedtuple)
105+
106+
Returns the tuple of index values for an array with `names`, when indexed by keywords.
107+
Any dimensions not fixed are given as `:`, to make a slice.
108+
An error is thrown if any keywords are used which do not occur in `nda`'s names.
109+
"""
110+
@inline function order_named_inds(val::Val{L}; kwargs...) where {L}
111+
if isempty(kwargs)
112+
return ()
113+
else
114+
return order_named_inds(val, kwargs.data)
115+
end
116+
end
117+
@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K}
118+
tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K"))
119+
exs = map(L) do n
120+
if Base.sym_in(n, K)
121+
qn = QuoteNode(n)
122+
:(getfield(ni, $qn))
123+
else
124+
:(Colon())
125+
end
126+
end
127+
return Expr(:tuple, exs...)
128+
end
129+
130+
"""
131+
size(A)
132+
133+
Returns the size of `A`. If the size of any axes are known at compile time,
134+
these should be returned as `Static` numbers. For example:
135+
```julia
136+
julia> using StaticArrays, ArrayInterface
137+
138+
julia> A = @SMatrix rand(3,4);
139+
140+
julia> ArrayInterface.size(A)
141+
(StaticInt{3}(), StaticInt{4}())
142+
```
143+
"""
144+
size(A) = Base.size(A)
145+
size(A, d) = Base.size(A, to_dims(A, d))
146+
147+
"""
148+
axes(A, d)
149+
150+
Return a valid range that maps to each index along dimension `d` of `A`.
151+
"""
152+
axes(A, d) = Base.axes(A, to_dims(A, d))
153+
154+
"""
155+
axes(A)
156+
157+
Return a tuple of ranges where each range maps to each element along a dimension of `A`.
158+
"""
159+
axes(A) = Base.axes(A)
160+

src/indexing.jl

Lines changed: 93 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -390,16 +390,29 @@ Changing indexing based on a given argument from `args` should be done through
390390
[`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref).
391391
"""
392392
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args))
393+
@propagate_inbounds function getindex(A; kwargs...)
394+
if has_dimnames(A)
395+
return A[order_named_inds(Val(dimnames(A)); kwargs...)...]
396+
else
397+
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
398+
end
399+
end
393400

394401
"""
395402
unsafe_getindex(A, inds)
396403
397404
Indexes into `A` given `inds`. This method assumes that `inds` have already been
398405
bounds-checked.
399406
"""
400-
unsafe_getindex(A, inds) = unsafe_getindex(UnsafeIndex(A, inds), A, inds)
401-
unsafe_getindex(::UnsafeGetElement, A, inds) = unsafe_get_element(A, inds)
402-
unsafe_getindex(::UnsafeGetCollection, A, inds) = unsafe_get_collection(A, inds)
407+
function unsafe_getindex(A, inds; kwargs...)
408+
return unsafe_getindex(UnsafeIndex(A, inds), A, inds; kwargs...)
409+
end
410+
function unsafe_getindex(::UnsafeGetElement, A, inds; kwargs...)
411+
return unsafe_get_element(A, inds; kwargs...)
412+
end
413+
function unsafe_getindex(::UnsafeGetCollection, A, inds; kwargs...)
414+
return unsafe_get_collection(A, inds; kwargs...)
415+
end
403416

404417
"""
405418
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
@@ -408,9 +421,7 @@ Returns an element of `A` at the indices `inds`. This method assumes all `inds`
408421
have been checked for being in bounds. Any new array type using `ArrayInterface.getindex`
409422
must define `unsafe_get_element(::NewArrayType, inds)`.
410423
"""
411-
function unsafe_get_element(A, inds)
412-
throw(MethodError(unsafe_getindex, (A, inds)))
413-
end
424+
unsafe_get_element(A, inds; kwargs...) = throw(MethodError(unsafe_getindex, (A, inds)))
414425
function unsafe_get_element(A::Array, inds)
415426
if length(inds) === 0
416427
return Base.arrayref(false, A, 1)
@@ -433,11 +444,11 @@ end
433444
434445
Returns a collection of `A` given `inds`. `inds` is assumed to have been bounds-checked.
435446
"""
436-
function unsafe_get_collection(A, inds)
447+
function unsafe_get_collection(A, inds; kwargs...)
437448
axs = to_axes(A, inds)
438449
dest = similar(A, axs)
439450
if map(Base.unsafe_length, axes(dest)) == map(Base.unsafe_length, axs)
440-
Base._unsafe_getindex!(dest, A, inds...) # usually a generated function, don't allow it to impact inference result
451+
_unsafe_getindex!(dest, A, inds...; kwargs...) # usually a generated function, don't allow it to impact inference result
441452
else
442453
Base.throw_checksize_error(dest, axs)
443454
end
@@ -490,16 +501,29 @@ Store the given values at the given key or index within a collection.
490501
"elements after construction.")
491502
end
492503
end
504+
@propagate_inbounds function setindex!(A, val; kwargs...)
505+
if has_dimnames(A)
506+
A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val
507+
else
508+
return unsafe_setindex!(A, val, to_indices(A, ()); kwargs...)
509+
end
510+
end
493511

494512
"""
495-
unsafe_setindex!(A, val, inds::Tuple)
513+
unsafe_setindex!(A, val, inds::Tuple; kwargs...)
496514
497515
Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
498516
bounds-checked. This step of the processing pipeline can be customized by:
499517
"""
500-
unsafe_setindex!(A, val, inds::Tuple) = unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds)
501-
unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple) = unsafe_set_element!(A, val, inds)
502-
unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple) = unsafe_set_collection!(A, val, inds)
518+
function unsafe_setindex!(A, val, inds::Tuple; kwargs...)
519+
return unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds; kwargs...)
520+
end
521+
function unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple; kwargs...)
522+
return unsafe_set_element!(A, val, inds; kwargs...)
523+
end
524+
function unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple; kwargs...)
525+
return unsafe_set_collection!(A, val, inds; kwargs...)
526+
end
503527

504528
"""
505529
unsafe_set_element!(A, val, inds::Tuple)
@@ -508,7 +532,7 @@ Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds
508532
have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!`
509533
must define `unsafe_set_element!(::NewArrayType, val, inds)`.
510534
"""
511-
function unsafe_set_element!(A, val, inds)
535+
function unsafe_set_element!(A, val, inds; kwargs...)
512536
throw(MethodError(unsafe_set_element!, (A, val, inds)))
513537
end
514538
function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T}
@@ -527,6 +551,60 @@ end
527551
528552
Sets `inds` of `A` to `val`. `inds` is assumed to have been bounds-checked.
529553
"""
530-
@inline function unsafe_set_collection!(A, val, inds)
531-
return Base._unsafe_setindex!(IndexStyle(A), A, val, inds...)
554+
@inline function unsafe_set_collection!(A, val, inds; kwargs...)
555+
return _unsafe_setindex!(IndexStyle(A), A, val, inds...; kwargs...)
556+
end
557+
558+
559+
# these let us use `@ncall` on getindex/setindex! that have kwargs
560+
function _setindex_kwargs!(x, val, kwargs, args...)
561+
@inbounds setindex!(x, val, args...; kwargs...)
562+
end
563+
function _getindex_kwargs(x, kwargs, args...)
564+
@inbounds getindex(x, args...; kwargs...)
565+
end
566+
567+
function _generate_unsafe_getindex!_body(N::Int)
568+
quote
569+
Base.@_inline_meta
570+
D = eachindex(dest)
571+
Dy = iterate(D)
572+
@inbounds Base.Cartesian.@nloops $N j d->I[d] begin
573+
# This condition is never hit, but at the moment
574+
# the optimizer is not clever enough to split the union without it
575+
Dy === nothing && return dest
576+
(idx, state) = Dy
577+
dest[idx] = Base.Cartesian.@ncall $N _getindex_kwargs src kwargs j
578+
Dy = iterate(D, state)
579+
end
580+
return dest
581+
end
532582
end
583+
584+
function _generate_unsafe_setindex!_body(N::Int)
585+
quote
586+
x′ = Base.unalias(A, x)
587+
Base.Cartesian.@nexprs $N d->(I_d = Base.unalias(A, I[d]))
588+
idxlens = Base.Cartesian.@ncall $N Base.index_lengths I
589+
Base.Cartesian.@ncall $N Base.setindex_shape_check x′ (d->idxlens[d])
590+
Xy = iterate(x′)
591+
@inbounds Base.Cartesian.@nloops $N i d->I_d begin
592+
# This is never reached, but serves as an assumption for
593+
# the optimizer that it does not need to emit error paths
594+
Xy === nothing && break
595+
(val, state) = Xy
596+
Base.Cartesian.@ncall $N _setindex_kwargs! A val kwargs i
597+
Xy = iterate(x′, state)
598+
end
599+
A
600+
end
601+
end
602+
603+
@generated function _unsafe_getindex!(dest::AbstractArray, src::AbstractArray, I::Vararg{Union{Real, AbstractArray}, N}; kwargs...) where N
604+
_generate_unsafe_getindex!_body(N)
605+
end
606+
607+
@generated function _unsafe_setindex!(::IndexStyle, A::AbstractArray, x, I::Vararg{Union{Real,AbstractArray}, N}; kwargs...) where N
608+
_generate_unsafe_setindex!_body(N)
609+
end
610+

src/ranges.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@ Otherwise, return `nothing`.
99
@test isone(known_first(typeof(Base.OneTo(4))))
1010
"""
1111
known_first(x) = known_first(typeof(x))
12-
known_first(::Type{T}) where {T} = nothing
12+
function known_first(::Type{T}) where {T}
13+
if parent_type(T) <: T
14+
return nothing
15+
else
16+
return known_first(parent_type(T))
17+
end
18+
end
1319
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
14-
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))
1520

1621
"""
1722
known_last(::Type{T})
@@ -24,8 +29,13 @@ using StaticArrays
2429
@test known_last(typeof(SOneTo(4))) == 4
2530
"""
2631
known_last(x) = known_last(typeof(x))
27-
known_last(::Type{T}) where {T} = nothing
28-
known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T))
32+
function known_last(::Type{T}) where {T}
33+
if parent_type(T) <: T
34+
return nothing
35+
else
36+
return known_last(parent_type(T))
37+
end
38+
end
2939

3040
"""
3141
known_step(::Type{T})
@@ -37,11 +47,15 @@ Otherwise, return `nothing`.
3747
@test isone(known_step(typeof(1:4)))
3848
"""
3949
known_step(x) = known_step(typeof(x))
40-
known_step(::Type{T}) where {T} = nothing
50+
function known_step(::Type{T}) where {T}
51+
if parent_type(T) <: T
52+
return nothing
53+
else
54+
return known_step(parent_type(T))
55+
end
56+
end
4157
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4258

43-
# add methods to support ArrayInterface
44-
4559
"""
4660
OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int}
4761
@@ -419,3 +433,4 @@ end
419433
lst = _try_static(static_last(x), static_last(y))
420434
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
421435
end
436+

0 commit comments

Comments
 (0)