Skip to content

Commit 0b5e133

Browse files
authored
Merge pull request #225 from ranocha/hr/axes_SubArray
`axes(A::SubArray, dim)` with `dim` larger than `ndims(A)`
2 parents 921bdcd + f077046 commit 0b5e133

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

src/axes.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Returns the axis associated with each dimension of `A` or dimension `dim`.
8888
exception of a handful of types replace `Base.OneTo{Int}` with `ArrayInterface.SOneTo`. For
8989
example, the axis along the first dimension of `Transpose{T,<:AbstractVector{T}}` and
9090
`Adjoint{T,<:AbstractVector{T}}` can be represented by `SOneTo(1)`. Similarly,
91-
`Base.ReinterpretArray`'s first axis may be statically sized.
91+
`Base.ReinterpretArray`'s first axis may be statically sized.
9292
"""
9393
@inline axes(A) = Base.axes(A)
9494
axes(A::ReshapedArray) = Base.axes(A)
@@ -112,7 +112,22 @@ end
112112
return getfield(axes(A), Int(dim))
113113
end
114114
end
115-
axes(A::SubArray, dim) = Base.axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1)
115+
116+
@inline function axes(A::SubArray, dim::Integer)
117+
if dim > ndims(A)
118+
return OneTo(1)
119+
else
120+
return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1)
121+
end
122+
end
123+
@inline function axes(A::SubArray, ::StaticInt{dim}) where {dim}
124+
if dim > ndims(A)
125+
return SOneTo{1}()
126+
else
127+
return axes(getindex(A.indices, to_parent_dims(A, to_dims(A, dim))), 1)
128+
end
129+
end
130+
116131
if isdefined(Base, :ReshapedReinterpretArray)
117132
function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}}
118133
if sizeof(S) > sizeof(T)

test/axes.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,29 @@ m = Array{Float64}(undef, 4, 3)
4444
@test_throws DimensionMismatch ArrayInterface.LazyAxis{0}(A)
4545
end
4646

47+
@testset "`axes(A, dim)`` with `dim > ndims(A)` (#224)" begin
48+
m = 2
49+
n = 3
50+
B = Array{Float64, 2}(undef, m, n)
51+
b = view(B, :, 1)
52+
53+
@test @inferred(ArrayInterface.axes(B, 1)) == 1:m
54+
@test @inferred(ArrayInterface.axes(B, 2)) == 1:n
55+
@test @inferred(ArrayInterface.axes(B, 3)) == 1:1
56+
57+
@test @inferred(ArrayInterface.axes(B, static(1))) == 1:m
58+
@test @inferred(ArrayInterface.axes(B, static(2))) == 1:n
59+
@test @inferred(ArrayInterface.axes(B, static(3))) == 1:1
60+
61+
@test @inferred(ArrayInterface.axes(b, 1)) == 1:m
62+
@test @inferred(ArrayInterface.axes(b, 2)) == 1:1
63+
@test @inferred(ArrayInterface.axes(b, 3)) == 1:1
64+
65+
@test @inferred(ArrayInterface.axes(b, static(1))) == 1:m
66+
@test @inferred(ArrayInterface.axes(b, static(2))) == 1:1
67+
@test @inferred(ArrayInterface.axes(b, static(3))) == 1:1
68+
end
69+
4770
if isdefined(Base, :ReshapedReinterpretArray)
4871
a = rand(3, 5)
4972
ua = reinterpret(reshape, UInt64, a)

0 commit comments

Comments
 (0)