Skip to content
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ end
Returns the parent array that type `T` wraps.
"""
parent_type(x) = parent_type(typeof(x))
parent_type(::Type{Symmetric{T,S}}) where {T,S} = S
parent_type(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = fieldtype(T, :data)
parent_type(::Type{<:AbstractTriangular{T,S}}) where {T,S} = S
parent_type(@nospecialize T::Type{<:PermutedDimsArray}) = fieldtype(T, :parent)
parent_type(@nospecialize T::Type{<:Adjoint}) = fieldtype(T, :parent)
Expand Down Expand Up @@ -667,6 +667,16 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int)
end
end

"""
IndexLabel(label)

A type that clearly communicates to internal methods to lookup the index corresponding to
for `label`.
"""
struct IndexLabel{L} <: ArrayIndex{1}
label::L
end

_cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i)
_cartesian_index(::Any) = nothing

Expand Down Expand Up @@ -766,6 +776,7 @@ julia> ArrayInterfaceCore.ndims_index([CartesianIndex(1, 2), CartesianIndex(1, 3
ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N
# preserve CartesianIndices{0} as they consume a dimension.
ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1
ndims_index(@nospecialize T::Type{<:Union{Number,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 1
ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T)
ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T))
ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask))
Expand Down Expand Up @@ -793,7 +804,7 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2)
ndims_shape(T::DataType) = ndims_index(T)
ndims_shape(::Type{Colon}) = 1
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 0
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
ndims_shape(x) = ndims_shape(typeof(x))
Expand Down
3 changes: 2 additions & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
issingular, isstructured, matrix_colors, restructure, lu_instance,
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, defines_strides,
stride_preserving_index

# ArrayIndex subtypes and methods
Expand Down Expand Up @@ -35,6 +35,7 @@ using Base.Iterators: Pairs
using LinearAlgebra

import Compat
using Compat: Returns

_add1(@nospecialize x) = x + oneunit(x)
_sub1(@nospecialize x) = x - oneunit(x)
Expand Down
113 changes: 111 additions & 2 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N
end
end


# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases
# breaking changes for this adopting this method to situations where they clearly benefit
# from the propagation of static axes. This creates the somewhat awkward situation of
Expand All @@ -113,7 +112,7 @@ axes(A::ReshapedArray) = Base.axes(A)
@inline function axes(x::Union{MatAdjTrans,PermutedDimsArray})
map(GetIndex{false}(axes(parent(x))), to_parent_dims(x))
end
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
axes(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes(parent(A)), 1))

@inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x))))
@inline _sub_axes(x::SubArray, axis::SOneTo) = axis
Expand Down Expand Up @@ -248,3 +247,113 @@ lazy_axes(x::AbstractRange, ::StaticInt{1}) = Base.axes1(x)
lazy_axes(x, ::Colon) = LazyAxis{:}(x)
lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x)
@inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims)

"""
has_index_labels(x) -> Bool

Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) returns a tuple of
`nothing`, this will be `false`.

See also: [`index_labels`](@ref)
"""
has_index_labels(x) = _any_labels(index_labels(x))
function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian})
has_index_labels(parent(x))
end
function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
if has_index_labels(parent(x))
true
else
size1 = div(sizeof(S), sizeof(T))
size1 > 1 && size1 === fieldcount(S)
end
end
function has_index_labels(x::SubArray)
if has_index_labels(parent(x))
return true
else
inds = x.indices
for i in 1:nfields(inds)
has_index_labels(getfield(inds, i)) && return true
end
return false
end
end
_any_labels(@nospecialize labels::Tuple{Vararg{Nothing}}) = false
_any_labels(@nospecialize labels::Tuple{Vararg{Any}}) = true

"""
index_labels(x)
index_labels(x, dim)

Returns a tuple of labels assigned to each axis or a collection of labels corresponding to
each index along `dim` of `x`. Default is to simply return `nothing`.

See also: [`has_index_labels`](@ref)
"""
index_labels(x, dim) = index_labels(x, to_dims(x, dim))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the function live in ArrayInterfaceCore so that existing "named array" packages can overload it? BTW, it would be good to ping their authors to ensure they would all be OK with the API, otherwise it won't make a lot of sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Existing packages can overload it from ArrayInterface. If you're referring to supporting named dimensions like NamedDims.jl, then they define ArrayInterface.dimnames and to_dims maps to the appropriate dimension so they don't have to overload every method with a dim argument.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayInterface is relatively heavy, which is why ArrayInterfaceCore was created IIUC. I guess it's up to package authors to say whether a dependency on ArrayInterface is acceptable for them or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was more of an issue before we went through all the trouble to fix invalidations due to StaticInt earlier this year. That doesn't mean we can't improve the situation. We are actively trying to move matured functionality into base where appropriate (see #340). I regularly review the code here in an effort to eliminate problematic code that still exists (e.g., redundancies, generated functions, etc,). For example, once we know how this PR is going to look I can finally finish an effort to consolidate a lot of "indexing.jl" by overloading base methods instead of reimplementing a lot of what's in base already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's a bit of a chicken and egg issue. We use StaticSymbol for dimension names known at compile time so that we can use them as a point of reference in an inferrible way. It's pretty difficult to do this only relying on constant propagation (demonstrated that with static sizes here JuliaLang/julia#44538 (comment)).

If someone has a reliable solution I'm open to it. I've been trying to actively address this and related issues for years

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oxinabox is it still a concern depending on ArrayInterface at this point?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that ArrayInterface doesn't like Requires.jl a billion packages, I am much more comfortable depending upon it for NamedDims.jl

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion was just to put empty function definitions in ArrayInterfaceCore (like we do with DataAPI and StatsAPI). That doesn't prevent keeping fallback method definitions in ArrayInterface as this PR does. But packages that don't want to use fallback definitions are still able to overload the functions by depending only on ArrayInterfaceCore.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've been trying to avoid situations where the behavior of a method is different if ArrayInterface is loaded vs just ArrayInterfaceCore.

index_labels(@nospecialize x::Number) = ()
@inline function index_labels(x, dim::CanonicalInt)
dim > ndims(x) ? nothing : getfield(index_labels(x), Int(dim))
end
@inline function index_labels(x)
if is_forwarding_wrapper(x)
index_labels(buffer(x))
else
ntuple(Returns(nothing), Val{ndims(x)}())
end
end
function index_labels(x::Union{MatAdjTrans,PermutedDimsArray})
map(GetIndex{false}(index_labels(parent(x))), to_parent_dims(x))
end
index_labels(x::VecAdjTrans) = (nothing, getfield(index_labels(parent(x)), 1))
function index_labels(x::SubArray)
labels = index_labels(parent(x))
inds = x.indices
info = IndicesInfo(x)
pdims = parentdims(info)
cdims = childdims(info)
flatten_tuples(ntuple(Val{nfields(pdims)}()) do i
pdim_i = getfield(pdims, i)
cdim_i = getfield(cdims, i)
index = getfield(inds, i)
if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes
index_labels(index)
elseif cdim_i === 0 # integer indexing drops axes
()
elseif pdim_i === 0 # trailing dimension
nothing
elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis
(getfield(labels, pdim_i),)
else
labels_i = getfield(labels, pdim_i)
labels_i === nothing ? index_labels(index) : (@inbounds(labels_i[index]),)
end
end)
end
index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first ∘ index_labels, x.indices)
index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x))
index_labels(@nospecialize(x::LazyAxis{:})) = (nothing,)
index_labels(x::LazyAxis{N}) where {N} = (getfield(index_labels(getfield(x, :parent)), N),)
@inline @inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S}
if sizeof(T) === sizeof(S)
return index_labels(parent(x))
else
return (nothing, Base.tail(index_labels(parent(x)))...)
end
end
function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
_reinterpret_index_labels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x)
end
@inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray})
S = eltype(parent_type(T))
isstructtype(S) ? fieldnames(S) : ()
end
function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N}
__reinterpret_index_labels(s, _reinterpreted_fieldnames(typeof(x)), index_labels(parent(x)))
end
@inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M}
N === M ? (fields, ks...,) : (nothing, ks...,)
end
_reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x))
_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = Base.tail(index_labels(parent(x)))
9 changes: 9 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ function _init_dimsmap(@nospecialize info::IndicesInfo)
ntuple(i -> static(getfield(cdims, i)), length(pdims))
end

parentdims(::IndicesInfo{<:Any,pdims}) where {pdims} = pdims

childdims(::IndicesInfo{<:Any,<:Any,cdims}) where {cdims} = cdims

"""
to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{StaticInt,Tuple{Vararg{StaticInt}}}}}

Expand Down Expand Up @@ -148,6 +152,9 @@ end
return ntuple(Compat.Returns(:_), StaticInt(ndims(T)))
end
end
known_dimnames(::Type{<:LazyAxis{:,P}}) where {P} = (first(known_dimnames(P)),)
known_dimnames(::Type{<:LazyAxis{N,P}}) where {N,P} = (getfield(known_dimnames(P), N),)

@inline function known_dimnames(::Type{T}) where {T}
if is_forwarding_wrapper(T)
return known_dimnames(parent_type(T))
Expand Down Expand Up @@ -207,6 +214,8 @@ end
return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x)))
end
end
dimnames(x::LazyAxis{:,P}) where {P} = (first(dimnames(getfield(x, :parent))),)
dimnames(x::LazyAxis{N,P}) where {N,P} = (getfield(dimnames(getfield(x, :parent)), N),)
@inline function dimnames(x::X) where {X}
if is_forwarding_wrapper(X)
return dimnames(parent(x))
Expand Down
23 changes: 21 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,33 @@ end
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
max(canonicalize(i.x), static_first(x)):static_last(x)
end
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
findall(i.f(i.x.label), first(index_labels(x)))
end
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
end
# integer indexing
to_index(x, i::AbstractArray{<:Integer}) = i
to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i
to_index(x, @nospecialize(i::StaticInt)) = i
to_index(x, i::Integer) = Int(i)
@inline to_index(x, i) = to_index(IndexStyle(x), x, i)
# key indexing
function to_index(x, i::IndexLabel)
index = findfirst(==(getfield(i, :label)), first(index_labels(x)))
# delay throwing bounds-error if we didn't find label
index === nothing ? offset1(x) - 1 : index
end
function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number})
index = findfirst(==(i), getfield(index_labels(x), 1))
index === nothing ? offset1(x) - 1 : index
end
# TODO there's probably a more efficient way of doing this
to_index(x, ks::AbstractArray{<:IndexLabel}) = [to_index(x, k) for k in ks]
function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}})
[to_index(x, k) for k in ks]
end

# integer indexing
function to_index(S::IndexStyle, x, i)
throw(ArgumentError(
"invalid index: $S does not support indices of type $(typeof(i)) for instances of type $(typeof(x))."
Expand Down
41 changes: 41 additions & 0 deletions test/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,44 @@ if isdefined(Base, :ReshapedReinterpretArray)
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
end
end

@testset "index_labels" begin
colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),));
caxis = ArrayInterface.LazyAxis{1}(colors);
colormat = reinterpret(reshape, Float64, colors);
cmat_view1 = view(colormat, :, 4);
cmat_view2 = view(colormat, :, 4:7);
cmat_view3 = view(colormat, 2:3,:);
absym_abstr = LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],));

@test @inferred(ArrayInterface.index_labels(colors)) == (range(-10, 10, length=100),)
@test @inferred(ArrayInterface.index_labels(caxis)) == (range(-10, 10, length=100),)
@test ArrayInterface.index_labels(view(colors, :, :), 2) === nothing
@test @inferred(ArrayInterface.index_labels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),)
@test @inferred(ArrayInterface.index_labels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100))
@test @inferred(ArrayInterface.index_labels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B))
@test @inferred(ArrayInterface.index_labels(cmat_view1)) == ((:R, :G, :B),)
@test @inferred((ArrayInterface.index_labels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787)
@test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (nothing, range(-10, 10, length=100))
# can't infer this b/c tuple is being indexed by range
@test ArrayInterface.index_labels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0)
@test @inferred(ArrayInterface.index_labels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595)

@test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (nothing, ["a", "b"])
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (nothing, [:a, :b], ["a", "b"])
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],)
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],)
@test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],)

@test ArrayInterface.has_index_labels(colors)
@test ArrayInterface.has_index_labels(caxis)
@test ArrayInterface.has_index_labels(colormat)
@test ArrayInterface.has_index_labels(cmat_view1)
@test !ArrayInterface.has_index_labels(view(colors, :, :))

@test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :]
@test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1]
@test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.IndexLabel(-9.595959595959595))) == colormat[:, 3]
@test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.IndexLabel(-9.595959595959595)))) == colormat[:, 1:3]
@test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]]
end
7 changes: 7 additions & 0 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ end
r4 = reinterpret(reshape, Float64, x)
w = Wrapper(x)
dnums = ntuple(+, length(d))
lz2 = ArrayInterface.lazy_axes(x)[2]
lzslice = ArrayInterface.LazyAxis{:}(x)

@test @inferred(ArrayInterface.has_dimnames(x)) == true
@test @inferred(ArrayInterface.has_dimnames(z)) == true
@test @inferred(ArrayInterface.has_dimnames(ones(2, 2))) == false
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) == false
@test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true
@test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true
@test @inferred(ArrayInterface.dimnames(x)) === d
@test @inferred(ArrayInterface.dimnames(lz2)) === (static(:y),)
@test @inferred(ArrayInterface.dimnames(lzslice)) === (static(:x),)
@test @inferred(ArrayInterface.dimnames(w)) === d
@test @inferred(ArrayInterface.dimnames(r1)) === d
@test @inferred(ArrayInterface.dimnames(r2)) === (static(:_), d...)
Expand Down Expand Up @@ -64,6 +69,8 @@ end
# multidmensional indices
@test @inferred(ArrayInterface.known_dimnames(view(x, ones(Int, 2, 2), 1))) === (:_, :_)
@test @inferred(ArrayInterface.known_dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (:_,)
@test @inferred(ArrayInterface.known_dimnames(lz2)) === (:y,)
@test @inferred(ArrayInterface.known_dimnames(lzslice)) === (:x,)

@test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y)
@test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y)
Expand Down
11 changes: 10 additions & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L
ArrayInterface.Static.known(L)
end

Base.parent(x::NamedDimsWrapper) = x.parent
struct LabelledArray{T,N,P<:AbstractArray{T,N},L} <: ArrayInterface.AbstractArray2{T,N}
parent::P
labels::L

LabelledArray(p::P, labels::L) where {P,L} = new{eltype(P),ndims(p),P,L}(p, labels)
end
ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true
Base.parent(x::LabelledArray) = getfield(x, :parent)
ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P
ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels)

# Dummy array type with undetermined contiguity properties
struct DummyZeros{T,N} <: AbstractArray{T,N}
Expand Down