Skip to content

Commit 9d0843f

Browse files
authored
Clean @pure issues in "dimensions.jl" and "static.jl" (#119)
* Cleaned up code in src/dimensions.jl and test/dimensions.jl * Add StaticSymbol * Incorporate @aggressive_constprop and README docs
1 parent db2cfb4 commit 9d0843f

File tree

11 files changed

+332
-222
lines changed

11 files changed

+332
-222
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.0.2"
3+
version = "3.1"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ If unknown, it returns `nothing`.
152152

153153
## contiguous_axis_indicator(::Type{T})
154154

155-
Returns a tuple of boolean `Val`s indicating whether that axis is contiguous.
155+
Returns a tuple of boolean `StaticBool`s indicating whether that axis is contiguous.
156156

157157
## contiguous_batch_size(::Type{T})
158158

@@ -167,7 +167,7 @@ Returns the rank of each stride.
167167

168168
## is_column_major(A)
169169

170-
Returns a `Val{true}()` if `A` is column major, and a `Val{false}()` otherwise.`
170+
Returns a `True` if `A` is column major, and a `True/False` otherwise.
171171

172172
## dense_dims(::Type{T})
173173
Returns a tuple of indicators for whether each axis is dense.
@@ -208,6 +208,10 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1)
208208

209209
Is the function `f` whitelisted for `LoopVectorization.@avx`?
210210

211+
## static(x)
212+
Returns a static form of `x`. If `x` is already in a static form then `x` is returned. If
213+
there is no static alternative for `x` then an error is thrown.
214+
211215
## StaticInt(N::Int)
212216

213217
Creates a static integer with value known at compile time. It is a number,

src/ArrayInterface.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,17 @@ using IfElse
44
using Requires
55
using LinearAlgebra
66
using SparseArrays
7+
using Base.Cartesian
78

8-
using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray
9+
using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray
10+
11+
@static if VERSION >= v"1.7.0-DEV.421"
12+
using Base: @aggressive_constprop
13+
else
14+
macro aggressive_constprop(ex)
15+
ex
16+
end
17+
end
918

1019
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
1120
parameterless_type(x) = parameterless_type(typeof(x))

src/dimensions.jl

Lines changed: 93 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A
3939
end
4040
out
4141
end
42-
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I}
43-
return _val_to_static(Val(I))
44-
end
42+
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))
4543

4644
"""
4745
to_parent_dims(::Type{T}) -> Bool
@@ -51,7 +49,7 @@ Returns the mapping from child dimensions to parent dimensions.
5149
to_parent_dims(x) = to_parent_dims(typeof(x))
5250
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
5351
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
54-
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I))
52+
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = static(Val(I))
5553
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
5654
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
5755
out = Expr(:tuple)
@@ -79,37 +77,44 @@ function has_dimnames(::Type{T}) where {T}
7977
end
8078
end
8179

80+
# this takes the place of dimension names that aren't defined
81+
const SUnderscore = StaticSymbol(:_)
82+
8283
"""
83-
dimnames(::Type{T}) -> Tuple{Vararg{Symbol}}
84-
dimnames(::Type{T}, d) -> Symbol
84+
dimnames(::Type{T}) -> Tuple{Vararg{StaticSymbol}}
85+
dimnames(::Type{T}, dim) -> StaticSymbol
8586
8687
Return the names of the dimensions for `x`.
8788
"""
8889
@inline dimnames(x) = dimnames(typeof(x))
89-
@inline dimnames(x, i::Integer) = dimnames(typeof(x), i)
90-
@inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d))
91-
@inline function dimnames(::Type{T}) where {T}
92-
if parent_type(T) <: T
93-
return ntuple(i -> :_, Val(ndims(T)))
90+
@inline dimnames(x, dim::Int) = dimnames(typeof(x), dim)
91+
@inline dimnames(x, dim::StaticInt) = dimnames(typeof(x), dim)
92+
@inline function dimnames(::Type{T}, ::StaticInt{dim}) where {T,dim}
93+
if ndims(T) < dim
94+
return SUnderscore
9495
else
95-
return dimnames(parent_type(T))
96+
return getfield(dimnames(T), dim)
9697
end
9798
end
98-
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
99-
return _transpose_dimnames(Val(dimnames(parent_type(T))))
99+
@inline function dimnames(::Type{T}, dim::Int) where {T}
100+
if ndims(T) < dim
101+
return SUnderscore
102+
else
103+
return getfield(dimnames(T), dim)
104+
end
100105
end
101-
# inserting the Val here seems to help inferability; I got a test failure without it.
102-
function _transpose_dimnames(::Val{S}) where {S}
103-
if length(S) == 1
104-
(:_, first(S))
105-
elseif length(S) == 2
106-
(last(S), first(S))
106+
@inline function dimnames(::Type{T}) where {T}
107+
if parent_type(T) <: T
108+
return ntuple(_ -> SUnderscore, Val(ndims(T)))
107109
else
108-
throw("Can't transpose $S of dim $(length(S)).")
110+
return dimnames(parent_type(T))
109111
end
110112
end
111-
@inline _transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
112-
@inline _transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
113+
@inline function dimnames(::Type{T}) where {T<:Union{Adjoint,Transpose}}
114+
_transpose_dimnames(dimnames(parent_type(T)))
115+
end
116+
@inline _transpose_dimnames(x::Tuple{Any,Any}) = (last(x), first(x))
117+
@inline _transpose_dimnames(x::Tuple{Any}) = (SUnderscore, first(x))
113118

114119
@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
115120
return map(i -> dimnames(parent_type(T), i), I)
@@ -123,7 +128,7 @@ end
123128
for i in 1:length(I)
124129
if I[i] > 0
125130
if nl < i
126-
push!(e.args, QuoteNode(:_))
131+
push!(e.args, :(ArrayInterface.SUnderscore))
127132
else
128133
push!(e.args, QuoteNode(L[i]))
129134
end
@@ -132,83 +137,79 @@ end
132137
return e
133138
end
134139

135-
"""
136-
to_dims(x[, d])
140+
_to_int(x::Integer) = Int(x)
141+
_to_int(x::StaticInt) = x
137142

138-
This returns the dimension(s) of `x` corresponding to `d`.
139-
"""
140-
to_dims(x, d) = to_dims(dimnames(x), d)
141-
to_dims(x::Tuple{Vararg{Symbol}}, d::Integer) = Int(d)
142-
to_dims(x::Tuple{Vararg{Symbol}}, d::Colon) = d # `:` is the default for most methods that take `dims`
143-
@inline to_dims(x::Tuple{Vararg{Symbol}}, d::Tuple) = map(i -> to_dims(x, i), d)
144-
@inline function to_dims(x::Tuple{Vararg{Symbol}}, d::Symbol)::Int
145-
i = _sym_to_dim(x, d)
146-
if i === 0
147-
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(x))"))
148-
end
149-
return i
150-
end
151-
Base.@pure function _sym_to_dim(x::Tuple{Vararg{Symbol,N}}, sym::Symbol) where {N}
152-
for i in 1:N
153-
getfield(x, i) === sym && return i
154-
end
155-
return 0
143+
function no_dimname_error(@nospecialize(x), @nospecialize(dim))
144+
throw(ArgumentError("($(repr(dim))) does not correspond to any dimension of ($(x))"))
156145
end
157146

158147
"""
159-
tuple_issubset
148+
to_dims(::Type{T}, dim) -> Integer
160149
161-
A version of `issubset` sepecifically for `Tuple`s of `Symbol`s, that is `@pure`.
162-
This helps it get optimised out of existance. It is less of an abuse of `@pure` than
163-
most of the stuff for making `NamedTuples` work.
150+
This returns the dimension(s) of `x` corresponding to `d`.
164151
"""
165-
Base.@pure function tuple_issubset(
166-
lhs::Tuple{Vararg{Symbol,N}}, rhs::Tuple{Vararg{Symbol,M}}
167-
) where {N,M}
168-
N <= M || return false
169-
for a in lhs
170-
found = false
171-
for b in rhs
172-
found |= a === b
173-
end
174-
found || return false
175-
end
176-
return true
152+
to_dims(x, dim) = to_dims(typeof(x), dim)
153+
to_dims(::Type{T}, dim::Integer) where {T} = _to_int(dim)
154+
to_dims(::Type{T}, dim::Colon) where {T} = dim
155+
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
156+
i = find_first_eq(dim, dimnames(T))
157+
i === nothing && no_dimname_error(T, dim)
158+
return i
159+
end
160+
@inline function to_dims(::Type{T}, dim::Symbol) where {T}
161+
i = find_first_eq(dim, Symbol.(dimnames(T)))
162+
i === nothing && no_dimname_error(T, dim)
163+
return i
177164
end
165+
to_dims(::Type{T}, dims::Tuple) where {T} = map(i -> to_dims(T, i), dims)
178166

179-
"""
180-
order_named_inds(Val(names); kwargs...)
181-
order_named_inds(Val(names), namedtuple)
167+
#=
168+
order_named_inds(names, namedtuple)
169+
order_named_inds(names, subnames, inds)
182170
183171
Returns the tuple of index values for an array with `names`, when indexed by keywords.
184172
Any dimensions not fixed are given as `:`, to make a slice.
185173
An error is thrown if any keywords are used which do not occur in `nda`'s names.
186-
"""
187-
@inline function order_named_inds(val::Val{L}; kwargs...) where {L}
188-
if isempty(kwargs)
189-
return ()
174+
175+
176+
1. parse into static dimnension names and key words.
177+
2. find each dimnames in key words
178+
3. if nothing is found use Colon()
179+
4. if (ndims - ncolon) === nkwargs then all were found, else error
180+
=#
181+
order_named_inds(x::Tuple, ::NamedTuple{(),Tuple{}}) = ()
182+
function order_named_inds(x::Tuple, nd::NamedTuple{L}) where {L}
183+
return order_named_inds(x, static(Val(L)), Tuple(nd))
184+
end
185+
@aggressive_constprop function order_named_inds(
186+
x::Tuple{Vararg{Any,N}},
187+
nd::Tuple,
188+
inds::Tuple
189+
) where {N}
190+
191+
out = eachop(((x, nd, inds), i) -> order_named_inds(x, nd, inds, i), (x, nd, inds), nstatic(Val(N)))
192+
_order_named_inds_check(out, length(nd))
193+
return out
194+
end
195+
function order_named_inds(x::Tuple, nd::Tuple, inds::Tuple, ::StaticInt{dim}) where {dim}
196+
index = find_first_eq(getfield(x, dim), nd)
197+
if index === nothing
198+
return Colon()
190199
else
191-
return order_named_inds(val, kwargs.data)
192-
end
193-
end
194-
@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K}
195-
tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K"))
196-
exs = map(L) do n
197-
if Base.sym_in(n, K)
198-
qn = QuoteNode(n)
199-
:(getfield(ni, $qn))
200-
else
201-
:(Colon())
202-
end
200+
return @inbounds(inds[index])
203201
end
204-
return Expr(:tuple, exs...)
205202
end
206-
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
207-
out = Expr(:curly, :Tuple)
208-
for p in P
209-
push!(out.args, T.parameters[p])
203+
204+
ncolon(x::Tuple{Colon,Vararg}, n::Int) = ncolon(tail(x), n + 1)
205+
ncolon(x::Tuple{Any,Vararg}, n::Int) = ncolon(tail(x), n)
206+
ncolon(x::Tuple{Colon}, n::Int) = n + 1
207+
ncolon(x::Tuple{Any}, n::Int) = n
208+
function _order_named_inds_check(inds::Tuple{Vararg{Any,N}}, nkwargs::Int) where {N}
209+
if (N - ncolon(inds, 0)) !== nkwargs
210+
error("Not all keywords matched dimension names.")
210211
end
211-
Expr(:block, Expr(:meta, :inline), out)
212+
return nothing
212213
end
213214

214215
"""
@@ -226,14 +227,11 @@ function axes_types(::Type{T}) where {T}
226227
return axes_types(parent_type(T))
227228
end
228229
end
229-
function axes_types(::Type{T}) where {T<:Adjoint}
230-
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
230+
function axes_types(::Type{T}) where {T<:MatAdjTrans}
231+
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
231232
end
232-
function axes_types(::Type{T}) where {T<:Transpose}
233-
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
234-
end
235-
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
236-
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
233+
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
234+
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
237235
end
238236
function axes_types(::Type{T}) where {T<:AbstractRange}
239237
if known_length(T) === nothing
@@ -311,8 +309,6 @@ end
311309
Expr(:block, Expr(:meta, :inline), out)
312310
end
313311

314-
315-
316312
"""
317313
size(A)
318314
@@ -330,12 +326,7 @@ julia> ArrayInterface.size(A)
330326
@inline size(A) = Base.size(A)
331327
@inline size(A, d::Integer) = size(A)[Int(d)]
332328
@inline size(A, d) = Base.size(A, to_dims(A, d))
333-
@inline function size(x::LinearAlgebra.Adjoint{T,V}) where {T,V<:AbstractVector{T}}
334-
return (One(), static_length(x))
335-
end
336-
@inline function size(x::LinearAlgebra.Transpose{T,V}) where {T,V<:AbstractVector{T}}
337-
return (One(), static_length(x))
338-
end
329+
@inline size(x::VecAdjTrans) = (One(), static_length(x))
339330

340331
function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
341332
return _size(size(parent(B)), B.indices, map(static_length, B.indices))
@@ -357,9 +348,9 @@ end
357348
Expr(:block, Expr(:meta, :inline), t)
358349
end
359350
@inline size(v::AbstractVector) = (static_length(v),)
360-
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}())
351+
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
361352
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
362-
return permute(size(parent(B)), Val{I1}())
353+
return permute(size(parent(B)), to_parent_dims(B))
363354
end
364355
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]
365356
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N]

src/indexing.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ Changing indexing based on a given argument from `args` should be done through
432432
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args))
433433
@propagate_inbounds function getindex(A; kwargs...)
434434
if has_dimnames(A)
435-
return A[order_named_inds(Val(dimnames(A)); kwargs...)...]
435+
return A[order_named_inds(dimnames(A), kwargs.data)...]
436436
else
437437
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
438438
end
@@ -548,7 +548,7 @@ Store the given values at the given key or index within a collection.
548548
end
549549
@propagate_inbounds function setindex!(A, val; kwargs...)
550550
if has_dimnames(A)
551-
A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val
551+
A[order_named_inds(dimnames(A), kwargs.data)...] = val
552552
else
553553
return unsafe_setindex!(A, val, to_indices(A, ()); kwargs...)
554554
end
@@ -662,3 +662,4 @@ end
662662
) where {N}
663663
return _generate_unsafe_setindex!_body(N)
664664
end
665+

src/ranges.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,3 +607,4 @@ function Base.show(io::IO, r::OptionallyStaticRange)
607607
end
608608
print(io, last(r))
609609
end
610+

0 commit comments

Comments
 (0)