diff --git a/src/axes.jl b/src/axes.jl index 9fdb4f787..8b8499875 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -77,8 +77,8 @@ function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D} end end -# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases -# breaking changes for this adopting this method to situations where they clearly benefit +# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is intended to decrease +# breaking changes for adapting this method to situations where there's clearly benefit # from the propagation of static axes. This creates the somewhat awkward situation of # conditionally typed (but inferrable) axes. It also means we can't depend on constant # propagation to preserve statically sized axes. This should probably be addressed before @@ -293,3 +293,37 @@ lazy_axes(x::CartesianIndices) = axes(x) @inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x)))) @inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x)) +""" + axes_keys(x) + axes_keys(x, dim) + + +Returns a tuple of keys assigned to each axis or the axis at dimension `dim` for `x`. Default +is to simply return `map(keys, axes(x))`. +""" +@inline axes_keys(A) = map(keys, axes(A)) +axes_keys(A::PermutedDimsArray) = permute(axes_keys(parent(A)), to_parent_dims(A)) +axes_keys(A::MatAdjTrans) = permute(axes_keys(parent(A)), to_parent_dims(A)) +axes_keys(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1)) +function axes_axes(A::SubArray{T,N}) where {T,N} + pdims = to_parent_dims(A) + ntuple(dim -> axes_keys(parent(A), pdims[dim])[A.indices[dim]], Val(N)) +end + +axes_keys(A, dim) = axes_keys(A, to_dims(A, dim)) +@inline function axes_keys(A, dim::Int) + if dim > ndims(A) + return OneTo(1) + else + return getfield(axes_keys(A), dim) + end +end +@inline function axes_keys(A, ::StaticInt{dim}) where {dim} + if dim > ndims(A) + return SOneTo{1}() + else + return getfield(axes(A), dim) + end +end +axes_keys(axis::LazyAxis{N,P}) where {N,P} = axes_keys(getfield(x, :parent), static(N)) +