Skip to content

Commit e25f898

Browse files
authored
Use is_forwarding_wrapper trait instead of parent_type(T) <: T (#296)
Using the heuristic parent_type(T) <: T sometimes gives incorrect results and may hide holes in the interface (as I found when making these changes). Some problems I found: * is_increasing(stride_rank(A)) didn't account for A having zero dimensions * is_dense(A) didn't know how to handle when dense_dims(A) returned nothing. * contiguous_axis(::Type{<:ReshapedArray}) would sometimes return nothing (meaning the contiguous axis is unknown) instead of StaticInt(-1) when we actually knew there couldn't be a contiguous axis because it didn't exist in the parent array. * (known)_dimnames(A) was often returning incorrect results for ReinterpretArray and ReshapedArray.
1 parent fb4a44b commit e25f898

File tree

10 files changed

+216
-113
lines changed

10 files changed

+216
-113
lines changed

docs/src/index.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@ Creating an array type with unique behavior in Julia is often accomplished by cr
3434
This allows the new array type to inherit functionality by redirecting methods to the parent array (e.g., `Base.size(x::Wrapper) = size(parent(x))`).
3535
Generic design limits the need to define an excessive number of methods like this.
3636
However, methods used to describe a type's traits often need to be explicitly defined for each trait method.
37-
`ArrayInterface` assists with this by providing information about the parent type using [`ArrayInterface.parent_type`](@ref).
38-
By default `ArrayInterface.parent_type(::Type{T})` returns `T` (analogous to `Base.parent(x) = x`).
39-
If any type other than `T` is returned we assume `T` wraps a parent structure, so methods know to unwrap instances of `T`.
40-
It is also assumed that if `T` has a parent type `Base.parent` is defined.
37+
If the the underlying data and access to it are unchanged by it's wrapper the [`ArrayInterface.is_forwarding_wrapper`](@ref) trait can signal to other trait methods to access its parent data structure.
38+
Supporting this for a new type only requires defines these methods:
39+
40+
```julia
41+
ArrayInterface.is_forwarding_wrapper(::Type{<:NewType}) = true
42+
ArrayInterface.parent_type(::Type{<:NewType}) = NewTypeParent
43+
Base.parent(x::NewType) = x.parent
44+
```
4145

4246
For those authoring new trait methods, this may change the default definition from `has_trait(::Type{T}) where {T} = false`, to:
4347
```julia
4448
function has_trait(::Type{T}) where {T}
45-
if parent_type(T) <:T
49+
if is_forwarding_wrapper(T)
4650
return false
4751
else
4852
return has_trait(parent_type(T))

src/ArrayInterface.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm,
8-
fast_scalar_indexing, parameterless_type, ndims_index, is_splat_index
8+
fast_scalar_indexing, parameterless_type, ndims_index, is_splat_index, is_forwarding_wrapper
99

1010
# ArrayIndex subtypes and methods
1111
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
@@ -34,6 +34,8 @@ using LinearAlgebra
3434

3535
import Compat
3636

37+
n_of_x(::StaticInt{N}, x::X) where {N,X} = ntuple(Compat.Returns(x), Val{N}())
38+
3739
@generated function merge_tuple_type(::Type{X}, ::Type{Y}) where {X<:Tuple,Y<:Tuple}
3840
Tuple{X.parameters...,Y.parameters...}
3941
end
@@ -48,7 +50,7 @@ Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
4850
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))
4951

5052
function Base.axes(A::AbstractArray2)
51-
!(parent_type(A) <: typeof(A)) && return ArrayInterface.axes(parent(A))
53+
is_forwarding_wrapper(A) && return ArrayInterface.axes(parent(A))
5254
throw(ArgumentError("Subtypes of `AbstractArray2` must define an axes method"))
5355
end
5456
function Base.axes(A::AbstractArray2, dim::Union{Symbol,StaticSymbol})
@@ -62,11 +64,7 @@ end
6264
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim))
6365

6466
function Base.IndexStyle(::Type{T}) where {T<:AbstractArray2}
65-
if parent_type(T) <: T
66-
return IndexCartesian()
67-
else
68-
return IndexStyle(parent_type(T))
69-
end
67+
is_forwarding_wrapper(T) ? IndexStyle(parent_type(T)) : IndexCartesian()
7068
end
7169

7270
function Base.length(A::AbstractArray2)

src/axes.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ end
2323
axes_types(x) = axes_types(typeof(x))
2424
axes_types(::Type{T}) where {T<:Array} = NTuple{ndims(T),OneTo{Int}}
2525
@inline function axes_types(::Type{T}) where {T}
26-
if parent_type(T) <: T
27-
return NTuple{ndims(T),OptionallyStaticUnitRange{One,Int}}
28-
else
26+
if is_forwarding_wrapper(T)
2927
return axes_types(parent_type(T))
28+
else
29+
return NTuple{ndims(T),OptionallyStaticUnitRange{One,Int}}
3030
end
3131
end
3232
axes_types(::Type{<:LinearIndices{N,R}}) where {N,R} = R
@@ -212,7 +212,7 @@ end
212212

213213
Base.keys(x::LazyAxis) = keys(parent(x))
214214

215-
Base.IndexStyle(::Type{<:LazyAxis}) = IndexStyle(parent_type(T))
215+
Base.IndexStyle(T::Type{<:LazyAxis}) = IndexStyle(parent_type(T))
216216

217217
ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent))
218218

src/dimensions.jl

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
3535
end
3636
end
3737
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
38+
is_increasing(::Tuple{}) = True()
3839

3940
"""
4041
from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
@@ -164,13 +165,53 @@ have a name.
164165
"""
165166
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim))
166167
known_dimnames(x) = known_dimnames(typeof(x))
167-
known_dimnames(::Type{T}) where {T} = _known_dimnames(T, parent_type(T))
168-
_known_dimnames(::Type{T}, ::Type{T}) where {T} = _unknown_dimnames(Base.IteratorSize(T))
169-
_unknown_dimnames(::Base.HasShape{N}) where {N} = ntuple(Compat.Returns(:_), Val(N))
168+
function known_dimnames(@nospecialize T::Type{<:VecAdjTrans})
169+
(:_, getfield(known_dimnames(parent_type(T)), 1))
170+
end
171+
function known_dimnames(@nospecialize T::Type{<:Union{MatAdjTrans,PermutedDimsArray,SubArray}})
172+
eachop(_inbounds_known_dimname, to_parent_dims(T), known_dimnames(parent_type(T)))
173+
end
174+
function known_dimnames(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
175+
pnames = known_dimnames(A)
176+
if IsReshaped
177+
if sizeof(S) === sizeof(T)
178+
return pnames
179+
elseif sizeof(S) > sizeof(T)
180+
return (:_, pnames...)
181+
else
182+
return tail(pnames)
183+
end
184+
else
185+
return pnames
186+
end
187+
end
188+
189+
@inline function known_dimnames(@nospecialize T::Type{<:Base.ReshapedArray})
190+
if ndims(T) === ndims(parent_type(T))
191+
return known_dimnames(parent_type(T))
192+
elseif ndims(T) > ndims(parent_type(T))
193+
return (known_dimnames(parent_type(T))..., n_of_x(StaticInt(ndims(T) - ndims(parent_type(T))), :_)...)
194+
else
195+
return n_of_x(StaticInt(ndims(T)), :_)
196+
end
197+
end
198+
@inline function known_dimnames(::Type{T}) where {T}
199+
if is_forwarding_wrapper(T)
200+
return known_dimnames(parent_type(T))
201+
else
202+
return _unknown_dimnames(Base.IteratorSize(T))
203+
end
204+
end
205+
206+
_unknown_dimnames(::Base.HasShape{N}) where {N} = n_of_x(StaticInt(N), :_)
170207
_unknown_dimnames(::Any) = (:_,)
208+
209+
#=
210+
_known_dimnames(::Type{T}, ::Type{T}) where {T} = _unknown_dimnames(Base.IteratorSize(T))
171211
function _known_dimnames(::Type{C}, ::Type{P}) where {C,P}
172212
eachop(_inbounds_known_dimname, to_parent_dims(C), known_dimnames(P))
173213
end
214+
=#
174215
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
175216
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
176217
(dim > N || dim < 1) && return :_
@@ -186,11 +227,41 @@ Return the names of the dimensions for `x`. `:_` is used to indicate a dimension
186227
have a name.
187228
"""
188229
@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim))
189-
@inline dimnames(x) = _dimnames(has_parent(x), x)
190-
@inline function _dimnames(::True, x)
191-
eachop(_inbounds_dimname, to_parent_dims(x), dimnames(parent(x)))
230+
@inline function dimnames(x::Union{MatAdjTrans,PermutedDimsArray,SubArray})
231+
eachop(_inbounds_known_dimname, to_parent_dims(x), dimnames(parent(x)))
232+
end
233+
dimnames(x::VecAdjTrans) = (static(:_), getfield(dimnames(parent(x)), 1))
234+
@inline function dimnames(x::ReinterpretArray{T,N,S,A,IsReshaped}) where {T,N,S,A,IsReshaped}
235+
pnames = dimnames(parent(x))
236+
if IsReshaped
237+
if sizeof(S) === sizeof(T)
238+
return pnames
239+
elseif sizeof(S) > sizeof(T)
240+
return (static(:_), pnames...)
241+
else
242+
return tail(pnames)
243+
end
244+
else
245+
return pnames
246+
end
247+
end
248+
@inline function dimnames(x::Base.ReshapedArray)
249+
p = parent(x)
250+
if ndims(x) === ndims(p)
251+
return dimnames(p)
252+
elseif ndims(x) > ndims(p)
253+
return (dimnames(p)..., n_of_x(StaticInt(ndims(x) - ndims(p)), static(:_))...)
254+
else
255+
return n_of_x(StaticInt(ndims(x)), static(:_))
256+
end
257+
end
258+
@inline function dimnames(x::X) where {X}
259+
if is_forwarding_wrapper(X)
260+
return dimnames(parent(x))
261+
else
262+
return n_of_x(StaticInt(ndims(x)), static(:_))
263+
end
192264
end
193-
_dimnames(::False, x) = ntuple(_->static(:_), Val(ndims(x)))
194265
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
195266
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
196267
# for calls such as `dimnames(view(x, :, 1, :))`

src/indexing.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,14 @@ end
318318

319319
## unsafe_getindex ##
320320
function unsafe_getindex(a::A) where {A}
321-
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,)))
321+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A,)))
322322
unsafe_getindex(parent(a))
323323
end
324324

325325
# TODO Need to manage index transformations between nested layers of arrays
326326
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
327327
if IndexStyle(A) === IndexLinear()
328-
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
328+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i)))
329329
return unsafe_getindex(parent(a), i)
330330
else
331331
return unsafe_getindex(a, _to_cartesian(a, i)...)
@@ -335,7 +335,7 @@ function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where
335335
if IndexStyle(A) === IndexLinear()
336336
return unsafe_getindex(a, _to_linear(a, (i, ii...)))
337337
else
338-
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
338+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i)))
339339
return unsafe_getindex(parent(a), i, ii...)
340340
end
341341
end
@@ -415,13 +415,13 @@ end
415415

416416
## unsafe_setindex! ##
417417
function unsafe_setindex!(a::A, v) where {A}
418-
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v)))
418+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v)))
419419
return unsafe_setindex!(parent(a), v)
420420
end
421421
# TODO Need to manage index transformations between nested layers of arrays
422422
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
423423
if IndexStyle(A) === IndexLinear()
424-
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
424+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i)))
425425
return unsafe_setindex!(parent(a), v, i)
426426
else
427427
return unsafe_setindex!(a, v, _to_cartesian(a, i)...)
@@ -431,7 +431,7 @@ function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) wh
431431
if IndexStyle(A) === IndexLinear()
432432
return unsafe_setindex!(a, v, _to_linear(a, (i, ii...)))
433433
else
434-
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i, ii...)))
434+
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i, ii...)))
435435
return unsafe_setindex!(parent(a), v, i, ii...)
436436
end
437437
end

src/size.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ julia> ArrayInterface.size(A)
1616
(static(3), static(4))
1717
```
1818
"""
19-
size(a::A) where {A} = _maybe_size(Base.IteratorSize(A), a)
19+
@inline function size(a::A) where {A}
20+
if is_forwarding_wrapper(A)
21+
return size(parent(a))
22+
else
23+
return _maybe_size(Base.IteratorSize(A), a)
24+
end
25+
end
2026
size(a::Base.Broadcast.Broadcasted) = map(length, axes(a))
2127

2228
_maybe_size(::Base.HasShape{N}, a::A) where {N,A} = map(length, axes(a))
@@ -56,15 +62,15 @@ end
5662
size(a, dim) = size(a, to_dims(a, dim))
5763
size(a::Array, dim::CanonicalInt) = Base.arraysize(a, convert(Int, dim))
5864
function size(a::A, dim::CanonicalInt) where {A}
59-
if parent_type(A) <: A
65+
if is_forwarding_wrapper(A)
66+
return size(parent(a), dim)
67+
else
6068
len = known_size(A, dim)
6169
if len === nothing
6270
return Int(length(axes(a, dim)))
6371
else
6472
return StaticInt(len)
6573
end
66-
else
67-
return size(a)[dim]
6874
end
6975
end
7076
function size(A::SubArray, dim::CanonicalInt)
@@ -86,6 +92,17 @@ compile time. If a dimension does not have a known size along a dimension then `
8692
returned in its position.
8793
"""
8894
known_size(x) = known_size(typeof(x))
95+
@inline function known_size(::Type{T}) where {T}
96+
if is_forwarding_wrapper(T)
97+
return known_size(parent_type(T))
98+
else
99+
return _maybe_known_size(Base.IteratorSize(T), T)
100+
end
101+
end
102+
function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
103+
eachop(_known_size, nstatic(Val(N)), axes_types(T))
104+
end
105+
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
89106
function known_size(::Type{T}) where {T<:AbstractRange}
90107
(_range_length(known_first(T), known_step(T), known_last(T)),)
91108
end
@@ -109,12 +126,6 @@ end
109126
dynamic(reduce_tup(_promote_shape, eachop(_unzip_size, nstatic(Val(known_length(T))), T)))
110127
end
111128
_unzip_size(::Type{T}, n::StaticInt{N}) where {T,N} = known_size(field_type(T, n))
112-
113-
known_size(::Type{T}) where {T} = _maybe_known_size(Base.IteratorSize(T), T)
114-
function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
115-
eachop(_known_size, nstatic(Val(N)), axes_types(T))
116-
end
117-
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
118129
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, dim))
119130
@inline known_size(x, dim) = known_size(typeof(x), dim)
120131
@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim))

0 commit comments

Comments
 (0)