Skip to content

Commit 75c9a32

Browse files
authored
Dynamic dimnames (#238)
* Enable optionally static/dynamic dimnames
1 parent b1ff0a2 commit 75c9a32

File tree

7 files changed

+134
-106
lines changed

7 files changed

+134
-106
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "4.0.2"
3+
version = "4.0.3"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ ArrayInterface.ismutable
2323
ArrayInterface.issingular
2424
ArrayInterface.isstructured
2525
ArrayInterface.is_splat_index
26+
ArrayInterface.known_dimnames
2627
ArrayInterface.known_first
2728
ArrayInterface.known_last
2829
ArrayInterface.known_length

docs/src/index.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,22 @@ New types defining dimension names can do something similar to:
100100
using Static
101101
using ArrayInterface
102102

103-
struct NewType{dnames} end # where dnames::Tuple{Vararg{Symbol}}
103+
struct StaticDimnames{dnames} end # where dnames::Tuple{Vararg{Symbol}}
104+
105+
ArrayInterface.known_dimnames(::Type{StaticDimnames{dnames}}) where {dnames} = dnames
106+
ArrayInterface.dimnames(::StaticDimnames{dnames}) where {dnames} = static(dnames)
107+
108+
struct DynamicDimnames{N}
109+
dimnames::NTuple{N,Symbol}
110+
end
111+
ArrayInterface.known_dimnames(::Type{DynamicDimnames{N}}) where {N} = ntuple(_-> missing, Val(N))
112+
ArrayInterface.dimnames(x::DynamicDimnames) = getfield(x, :dimnames)
104113

105-
ArrayInterface.dimnames(::Type{NewType{dnames}}) = static(dnames)
106114
```
107115

116+
Notice that `DynamicDimnames` returns `missing` instead of a symbol for each dimension.
117+
This indicates dimension names are present for `DynamicDimnames` but that information is missing at compile time.
118+
108119
Dimension names should be appropriately propagated between nested arrays using `ArrayInterface.to_parent_dims`.
109120
This allows types such as `SubArray` and `PermutedDimsArray` to work with named dimensions.
110121
Similarly, other methods that return information corresponding to dimensions (e.g., `ArrayInterfce.size`, `ArrayInterface.axes`) use `to_parent_dims` to appropriately propagate parent information.

src/axes.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,19 @@
55
66
Returns the type of each axis for the `T`, or the type of of the axis along dimension `dim`.
77
"""
8-
axes_types(x, dim) = axes_types(typeof(x), dim)
9-
@inline axes_types(::Type{T}, dim) where {T} = axes_types(T, to_dims(T, dim))
10-
@inline function axes_types(::Type{T}, dim::StaticInt{D}) where {T,D}
11-
if D > ndims(T)
8+
@inline axes_types(x, dim) = axes_types(x, to_dims(x, dim))
9+
@inline function axes_types(x, dim::StaticInt{D}) where {D}
10+
if D > ndims(x)
1211
return SOneTo{1}
1312
else
14-
return field_type(axes_types(T), dim)
13+
return field_type(axes_types(x), dim)
1514
end
1615
end
17-
@inline function axes_types(::Type{T}, dim::Int) where {T}
18-
if dim > ndims(T)
16+
@inline function axes_types(x, dim::Int)
17+
if dim > ndims(x)
1918
return SOneTo{1}
2019
else
21-
return axes_types(T).parameters[dim]
20+
return axes_types(x).parameters[dim]
2221
end
2322
end
2423
axes_types(x) = axes_types(typeof(x))
@@ -288,3 +287,4 @@ lazy_axes(x::CartesianIndices) = axes(x)
288287
@inline lazy_axes(x::MatAdjTrans) = reverse(lazy_axes(parent(x)))
289288
@inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x))))
290289
@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x))
290+

src/dimensions.jl

Lines changed: 75 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -147,68 +147,76 @@ function to_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim}
147147
end
148148
end
149149

150+
_nunderscore(::Val{N}) where {N} = ntuple(Compat.Returns(:_), Val(N))
151+
150152
"""
151-
has_dimnames(::Type{T}) -> Bool
153+
has_dimnames(::Type{T}) -> StaticBool
152154
153155
Returns `static(true)` if `x` has on or more named dimensions. If all dimensions correspond
154156
to `static(:_)`, then `static(false)` is returned.
155157
"""
156-
has_dimnames(x) = has_dimnames(typeof(x))
157-
@inline has_dimnames(::Type{T}) where {T} = _has_dimnames(dimnames(T))
158-
_has_dimnames(::Tuple{Vararg{StaticSymbol{:_}}}) = static(false)
159-
_has_dimnames(::Tuple) = static(true)
160-
161-
# this takes the place of dimension names that aren't defined
162-
const SUnderscore = StaticSymbol(:_)
158+
Compat.@constprop :aggressive has_dimnames(x) = static(_is_named(known_dimnames(x)))
159+
_is_named(x::NTuple{N,Symbol}) where {N} = x !== _nunderscore(Val(N))
160+
_is_named(::Any) = true
163161

164162
"""
165-
dimnames(::Type{T}) -> Tuple{Vararg{StaticSymbol}}
166-
dimnames(::Type{T}, dim) -> StaticSymbol
163+
known_dimnames(::Type{T}) -> Tuple{Vararg{Union{Symbol,Missing}}}
164+
known_dimnames(::Type{T}, dim::Union{Int,StaticInt}) -> Union{Symbol,Missing}
167165
168-
Return the names of the dimensions for `x`. `static(:_)` is used to indicate a dimension
169-
does not have a name.
166+
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
167+
have a name.
170168
"""
171-
@inline dimnames(x) = dimnames(typeof(x))
172-
@inline dimnames(::Type{T}) where {T} = _dimnames(has_parent(T), T)
173-
_dimnames(::False, ::Type{T}) where {T} = ntuple(_->static(:_), Val(ndims(T)))
174-
@inline function _dimnames(::True, ::Type{T}) where {T}
175-
eachop(_perm_dimnames, to_parent_dims(T), dimnames(parent_type(T)))
169+
@inline known_dimnames(x, dim::Integer) = _known_dimname(known_dimnames(x), canonicalize(dim))
170+
known_dimnames(x) = known_dimnames(typeof(x))
171+
known_dimnames(::Type{T}) where {T} = _known_dimnames(T, parent_type(T))
172+
_known_dimnames(::Type{T}, ::Type{T}) where {T} = _unknown_dimnames(Base.IteratorSize(T))
173+
_unknown_dimnames(::Base.HasShape{N}) where {N} = _nunderscore(Val(N))
174+
_unknown_dimnames(::Any) = (:_,)
175+
function _known_dimnames(::Type{C}, ::Type{P}) where {C,P}
176+
eachop(_inbounds_known_dimname, to_parent_dims(C), known_dimnames(P))
177+
end
178+
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
179+
@boundscheck (dim > N || dim < 1) && return :_
180+
return @inbounds(x[dim])
176181
end
182+
@inline _inbounds_known_dimname(x, dim) = @inbounds(_known_dimname(x, dim))
177183

178-
@inline dimnames(x, dim) = dimnames(typeof(x), dim)
179-
@inline dimnames(::Type{T}, dim::Integer) where {T} = _perm_dimnames(dimnames(T), dim)
180-
function _perm_dimnames(dnames::Tuple{Vararg{StaticSymbol,N}}, dim) where {N}
181-
if dim > N
182-
return static(:_)
183-
else
184-
return @inbounds(dnames[dim])
185-
end
184+
"""
185+
dimnames(x) -> Tuple{Vararg{Union{Symbol,StaticSymbol}}}
186+
dimnames(x, dim::Union{Int,StaticInt}) -> Union{Symbol,StaticSymbol}
187+
188+
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
189+
have a name.
190+
"""
191+
@inline dimnames(x, dim::Integer) = _dimname(dimnames(x), canonicalize(dim))
192+
@inline dimnames(x) = _dimnames(has_parent(x), x)
193+
@inline function _dimnames(::True, x)
194+
eachop(_inbounds_dimname, to_parent_dims(x), dimnames(parent(x)))
195+
end
196+
_dimnames(::False, x) = ntuple(_->static(:_), Val(ndims(x)))
197+
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
198+
@boundscheck (dim > N || dim < 1) && return static(:_)
199+
return @inbounds(x[dim])
186200
end
201+
@inline _inbounds_dimname(x, dim) = @inbounds(_dimname(x, dim))
187202

188203
"""
189-
to_dims(::Type{T}, dim) -> Union{Int,StaticInt}
204+
to_dims(x, dim) -> Union{Int,StaticInt}
190205
191-
This returns the dimension(s) of `x` corresponding to `d`.
206+
This returns the dimension(s) of `x` corresponding to `dim`.
192207
"""
193-
to_dims(x, dim) = to_dims(typeof(x), dim)
194-
to_dims(::Type{T}, dim::StaticInt) where {T} = dim
195-
to_dims(::Type{T}, dim::Integer) where {T} = Int(dim)
196-
to_dims(::Type{T}, dim::Colon) where {T} = dim
197-
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
198-
i = find_first_eq(dim, dimnames(T))
199-
if i === nothing
200-
throw_dim_error(T, dim)
201-
end
202-
return i
208+
to_dims(x, dim::Colon) = dim
209+
to_dims(x, dim::Integer) = canonicalize(dim)
210+
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
211+
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
212+
eachop(_to_dims, nstatic(Val(N)), dimnames(x), dims)
203213
end
204-
Compat.@constprop :aggressive function to_dims(::Type{T}, dim::Symbol) where {T}
205-
i = find_first_eq(dim, map(Symbol, dimnames(T)))
206-
if i === nothing
207-
throw_dim_error(T, dim)
208-
end
214+
@inline _to_dims(x::Tuple, d::Tuple, n::StaticInt{N}) where {N} = _to_dim(x, getfield(d, N))
215+
@inline function _to_dim(x::Tuple, d::Union{Symbol,StaticSymbol})
216+
i = find_first_eq(d, x)
217+
i === nothing && throw(DimensionMismatch("dimension name $(d) not found"))
209218
return i
210219
end
211-
to_dims(::Type{T}, dims::Tuple) where {T} = map(i -> to_dims(T, i), dims)
212220

213221
#=
214222
order_named_inds(names, namedtuple)
@@ -224,37 +232,31 @@ An error is thrown if any keywords are used which do not occur in `nda`'s names.
224232
3. if missing is found use Colon()
225233
4. if (ndims - ncolon) === nkwargs then all were found, else error
226234
=#
227-
order_named_inds(x::Tuple, ::NamedTuple{(),Tuple{}}) = ()
228-
function order_named_inds(x::Tuple, nd::NamedTuple{L}) where {L}
229-
return order_named_inds(x, static(Val(L)), Tuple(nd))
230-
end
231-
Compat.@constprop :aggressive function order_named_inds(
232-
x::Tuple{Vararg{Any,N}},
233-
nd::Tuple,
234-
inds::Tuple
235-
) where {N}
236-
237-
out = eachop(order_named_inds, nstatic(Val(N)), x, nd, inds)
238-
_order_named_inds_check(out, length(nd))
239-
return out
240-
end
241-
function order_named_inds(x::Tuple, nd::Tuple, inds::Tuple, ::StaticInt{dim}) where {dim}
242-
index = find_first_eq(getfield(x, dim), nd)
243-
if index === nothing
244-
return Colon()
235+
@generated function find_all_dimnames(x::Tuple{Vararg{Any,ND}}, nd::Tuple{Vararg{Any,NI}}, inds::Tuple, default) where {ND,NI}
236+
if NI === 0
237+
return :(())
245238
else
246-
return @inbounds(inds[index])
247-
end
248-
end
249-
250-
ncolon(x::Tuple{Colon,Vararg}, n::Int) = ncolon(tail(x), n + 1)
251-
ncolon(x::Tuple{Any,Vararg}, n::Int) = ncolon(tail(x), n)
252-
ncolon(x::Tuple{Colon}, n::Int) = n + 1
253-
ncolon(x::Tuple{Any}, n::Int) = n
254-
function _order_named_inds_check(inds::Tuple{Vararg{Any,N}}, nkwargs::Int) where {N}
255-
if (N - ncolon(inds, 0)) !== nkwargs
256-
error("Not all keywords matched dimension names.")
239+
out = Expr(:block, Expr(:(=), :names_found, 0))
240+
t = Expr(:tuple)
241+
for i in 1:ND
242+
index_i = Symbol(:index_, i)
243+
val_i = Symbol(:val_, i)
244+
push!(t.args, val_i)
245+
push!(out.args, quote
246+
$index_i = find_first_eq(getfield(x, $i), nd)
247+
if $index_i === nothing
248+
$val_i = default
249+
else
250+
$val_i = @inbounds(inds[$index_i])
251+
names_found += 1
252+
end
253+
end)
254+
end
255+
return quote
256+
$out
257+
@boundscheck names_found === $NI || error("Not all keywords matched dimension names.")
258+
return $t
259+
end
257260
end
258-
return missing
259261
end
260262

src/indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ function getindex(A, args...)
307307
@boundscheck checkbounds(A, inds...)
308308
unsafe_getindex(A, inds...)
309309
end
310-
function getindex(A; kwargs...)
311-
inds = to_indices(A, order_named_inds(dimnames(A), values(kwargs)))
310+
@propagate_inbounds function getindex(A; kwargs...)
311+
inds = to_indices(A, find_all_dimnames(dimnames(A), static(keys(kwargs)), Tuple(values(kwargs)), :))
312312
@boundscheck checkbounds(A, inds...)
313313
unsafe_getindex(A, inds...)
314314
end
@@ -407,7 +407,7 @@ Store the given values at the given key or index within a collection.
407407
end
408408
@propagate_inbounds function setindex!(A, val; kwargs...)
409409
can_setindex(A) || error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
410-
inds = to_indices(A, order_named_inds(dimnames(A), values(kwargs)))
410+
inds = to_indices(A, find_all_dimnames(dimnames(A), static(keys(kwargs)), Tuple(values(kwargs)), :))
411411
@boundscheck checkbounds(A, inds...)
412412
unsafe_setindex!(A, val, inds...)
413413
end

test/dimensions.jl

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
### define wrapper with dimnames
66
###
77

8-
struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractArray2{T,N}
8+
struct NamedDimsWrapper{D,T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractArray2{T,N}
9+
dimnames::D
910
parent::P
10-
NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p)
11+
NamedDimsWrapper(d::D, p::P) where {D,P} = new{D,eltype(P),ndims(p),P}(d, p)
1112
end
13+
Base.parent(x::NamedDimsWrapper) = getfield(x, :parent)
1214
ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P
13-
ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = static(L)
15+
ArrayInterface.dimnames(x::NamedDimsWrapper) = getfield(x, :dimnames)
16+
function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}}
17+
ArrayInterface.Static.known(L)
18+
end
19+
1420
Base.parent(x::NamedDimsWrapper) = x.parent
1521

1622
@testset "dimension permutations" begin
@@ -70,29 +76,31 @@ end
7076
n1 = (static(:x),)
7177
n2 = (n1..., static(:y))
7278
n3 = (n2..., static(:z))
73-
@test @inferred(ArrayInterface.order_named_inds(n1, NamedTuple{(),Tuple{}}(()) )) == ()
74-
@test @inferred(ArrayInterface.order_named_inds(n1, (x=2,))) == (2,)
75-
@test @inferred(ArrayInterface.order_named_inds(n2, (x=2,))) == (2, :)
76-
@test @inferred(ArrayInterface.order_named_inds(n2, (y=2,))) == (:, 2)
77-
@test @inferred(ArrayInterface.order_named_inds(n2, (y=20, x=30))) == (30, 20)
78-
@test @inferred(ArrayInterface.order_named_inds(n2, (x=30, y=20))) == (30, 20)
79-
@test @inferred(ArrayInterface.order_named_inds(n3, (x=30, y=20))) == (30, 20, :)
80-
81-
@test_throws ErrorException ArrayInterface.order_named_inds(n2, (x=30, y=20, z=40))
79+
@test @inferred(ArrayInterface.find_all_dimnames(n1, (), (), :)) == ()
80+
@test @inferred(ArrayInterface.find_all_dimnames(n1, (static(:x),), (2,), :)) == (2,)
81+
@test @inferred(ArrayInterface.find_all_dimnames(n2, (static(:x),), (2,), :)) == (2,:)
82+
@test @inferred(ArrayInterface.find_all_dimnames(n2, (static(:y),), (2,), :)) == (:, 2)
83+
@test @inferred(ArrayInterface.find_all_dimnames(n2, (static(:y), static(:x)), (20, 30), :)) == (30, 20)
84+
@test @inferred(ArrayInterface.find_all_dimnames(n2, (static(:x), static(:y)), (30, 20), :)) == (30, 20)
85+
@test @inferred(ArrayInterface.find_all_dimnames(n3, (static(:x), static(:y)), (30, 20), :)) == (30, 20, :)
86+
87+
@test_throws ErrorException ArrayInterface.find_all_dimnames(n2, (static(:x), static(:y), static(:z)), (30, 20, 40), :)
8288
end
8389

84-
8590
@testset "dimnames" begin
8691
d = (static(:x), static(:y))
87-
x = NamedDimsWrapper{d}(ones(2,2));
88-
y = NamedDimsWrapper{(:x,)}(ones(2));
92+
x = NamedDimsWrapper(d, ones(2,2));
93+
y = NamedDimsWrapper((static(:x),), ones(2));
94+
z = NamedDimsWrapper((:x, static(:y)), ones(2));
8995
dnums = ntuple(+, length(d))
9096
@test @inferred(ArrayInterface.has_dimnames(x)) == true
97+
@test @inferred(ArrayInterface.has_dimnames(z)) == true
9198
@test @inferred(ArrayInterface.has_dimnames(ones(2,2))) == false
9299
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) == false
93100
@test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true
94101
@test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true
95102
@test @inferred(dimnames(x)) === d
103+
@test @inferred(ArrayInterface.dimnames(z)) === (:x, static(:y))
96104
@test @inferred(dimnames(parent(x))) === (static(:_), static(:_))
97105
@test @inferred(dimnames(x')) === reverse(d)
98106
@test @inferred(dimnames(y')) === (static(:_), static(:x))
@@ -103,11 +111,14 @@ end
103111
@test @inferred(dimnames(view(x, :, 1, :))) === (static(:x), static(:_))
104112
@test @inferred(dimnames(x, ArrayInterface.One())) === static(:x)
105113
@test @inferred(dimnames(parent(x), ArrayInterface.One())) === static(:_)
114+
@test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10))) === (:_,)
115+
@test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10), static(1))) === :_
116+
@test @inferred(ArrayInterface.known_dimnames(z)) === (missing, :y)
106117
end
107118

108119
@testset "to_dims" begin
109-
x = NamedDimsWrapper{(:x, :y)}(ones(2,2));
110-
y = NamedDimsWrapper{(:x, :y, :a, :b, :c, :d)}(ones(6));
120+
x = NamedDimsWrapper(static((:x, :y)), ones(2,2));
121+
y = NamedDimsWrapper(static((:x, :y, :a, :b, :c, :d)), ones(6));
111122

112123
@test @inferred(ArrayInterface.to_dims(x, :)) == Colon()
113124
@test @inferred(ArrayInterface.to_dims(x, 1)) == 1
@@ -130,8 +141,8 @@ end
130141

131142
@testset "methods accepting dimnames" begin
132143
d = (static(:x), static(:y))
133-
x = NamedDimsWrapper{d}(ones(2,2));
134-
y = NamedDimsWrapper{(:x,)}(ones(2));
144+
x = NamedDimsWrapper(d, ones(2,2));
145+
y = NamedDimsWrapper((static(:x),), ones(2));
135146
@test @inferred(size(x, first(d))) == size(parent(x), 1)
136147
@test @inferred(ArrayInterface.size(y')) == (1, size(parent(x), 1))
137148
@test @inferred(axes(x, first(d))) == axes(parent(x), 1)
@@ -144,6 +155,9 @@ end
144155

145156
x[x = 1] = [2, 3]
146157
@test @inferred(getindex(x, x = 1)) == [2, 3]
158+
y = NamedDimsWrapper((:x, static(:y)), ones(2, 2));
159+
# FIXME this doesn't correctly infer the output because it can't infer
160+
@test getindex(y, x = 1) == [1, 1]
147161
end
148162

149163
end

0 commit comments

Comments
 (0)