Skip to content

Commit a8cc8d2

Browse files
authored
Quick fix for getting axes on SubArray with multidim indices (#297)
1 parent 98c1813 commit a8cc8d2

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.12"
3+
version = "6.0.13"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/axes.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,27 @@ axes(A::ReshapedArray) = Base.axes(A)
103103
axes(A::PermutedDimsArray) = permute(axes(parent(A)), to_parent_dims(A))
104104
axes(A::MatAdjTrans) = permute(axes(parent(A)), to_parent_dims(A))
105105
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
106-
axes(A::SubArray{<:Any,N}) where {N} = map(Base.Fix1(axes, A), Static.nstatic(Val(N)))
106+
axes(A::SubArray) = _sub_axes(parent(A), A.indices)
107+
@generated function _sub_axes(A, inds::I) where {N,P,I}
108+
out = Expr(:block, Expr(:meta, :inline))
109+
t = Expr(:tuple)
110+
for i in 1:fieldcount(I)
111+
I_i = fieldtype(I, i)
112+
if I_i <: Base.Slice{Base.OneTo{Int}}
113+
push!(t.args, :(axes(A, $i)))
114+
elseif ndims(I_i) === 1
115+
push!(t.args, :(getfield(axes(getfield(inds, $i)), 1)))
116+
else
117+
axsi = Symbol(:axes_, i)
118+
push!(out.args, :(axes(getfield(inds, $i))))
119+
for j in 1:ndims(I_i)
120+
push!(t.args, :(getfield($(axsi), $j)))
121+
end
122+
end
123+
end
124+
push!(out.args, t)
125+
out
126+
end
107127

108128
@inline axes(A, dim) = _axes(A, to_dims(A, dim))
109129
@inline function _axes(A, dim::Int)
@@ -120,27 +140,6 @@ end
120140
return getfield(axes(A), Int(dim))
121141
end
122142
end
123-
@inline function _inbound_axes(A::SubArray, dim)
124-
pd = to_parent_dims(A, to_dims(A, dim))
125-
ax = getindex(A.indices, pd)
126-
ax isa Base.Slice || return axes(ax, 1)
127-
axes(parent(A))[pd]
128-
end
129-
@inline function axes(A::SubArray, dim::CanonicalInt)
130-
if dim > ndims(A)
131-
return OneTo(1)
132-
else
133-
return _inbound_axes(A, dim)
134-
end
135-
end
136-
@inline function axes(A::SubArray, ::StaticInt{dim}) where {dim}
137-
if dim > ndims(A)
138-
return SOneTo{1}()
139-
else
140-
return _inbound_axes(A, StaticInt(dim))
141-
end
142-
end
143-
144143
function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}}
145144
if sizeof(S) > sizeof(T)
146145
return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(parent_type(A)))

test/axes.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ end
6666
@test @inferred(ArrayInterface.axes(b, static(1))) == 1:m
6767
@test @inferred(ArrayInterface.axes(b, static(2))) == 1:1
6868
@test @inferred(ArrayInterface.axes(b, static(3))) == 1:1
69+
70+
# multidimensional subindices
71+
vx = view(rand(4), reshape(1:4, 2, 2))
72+
@test @inferred(axes(vx)) == (1:2, 1:2)
6973
end
7074

7175
@testset "SubArray Adjoint Axis" begin

0 commit comments

Comments
 (0)