Skip to content

Commit 8d5eeb6

Browse files
authored
Merge pull request #167 from JuliaArrays/static_axis_1_for_adj_or_trans
Static axis 1 for adj or trans
2 parents 2a5777f + fa1e749 commit 8d5eeb6

File tree

4 files changed

+27
-14
lines changed

4 files changed

+27
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.16"
3+
version = "3.1.17"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/ArrayInterface.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,8 +1163,9 @@ function __init__()
11631163
function stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray}
11641164
return stride_rank(parent_type(A))
11651165
end
1166-
ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
1167-
ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
1166+
@inline axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
1167+
@inline _axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
1168+
@inline axes(A::OffsetArrays.OffsetArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
11681169
end
11691170
end
11701171

src/axes.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,35 +120,41 @@ similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N1}}}, ::Type{Int},
120120
121121
Return a valid range that maps to each index along dimension `d` of `A`.
122122
"""
123-
axes(a, dim) = axes(a, to_dims(a, dim))
124-
axes(a, dims::Tuple{Vararg{Any,K}}) where {K} = (axes(a, first(dims)), axes(a, tail(dims))...)
125-
axes(a, dims::Tuple{T}) where {T} = (axes(a, first(dims)), )
126-
axes(a, ::Tuple{}) = ()
127-
function axes(a::A, dim::Integer) where {A}
123+
@inline axes(a, dim) = axes(a, to_dims(a, dim))
124+
@inline axes(a, dims::Tuple{Vararg{Any,K}}) where {K} = (axes(a, first(dims)), axes(a, tail(dims))...)
125+
@inline axes(a, dims::Tuple{T}) where {T} = (axes(a, first(dims)), )
126+
@inline axes(a, ::Tuple{}) = ()
127+
@inline function _axes(a::A, dim::Integer) where {A}
128128
if parent_type(A) <: A
129129
return Base.axes(a, Int(dim))
130130
else
131-
return axes(parent(a), to_parent_dims(A, dim))
131+
return _axes(parent(a), to_parent_dims(A, dim))
132132
end
133133
end
134-
function axes(A::CartesianIndices{N}, dim::Integer) where {N}
134+
@inline function _axes(A::CartesianIndices{N}, dim::Integer) where {N}
135135
if dim > N
136136
return static(1):static(1)
137137
else
138138
return getfield(axes(A), Int(dim))
139139
end
140140
end
141-
function axes(A::LinearIndices{N}, dim::Integer) where {N}
141+
@inline function _axes(A::LinearIndices{N}, dim::Integer) where {N}
142142
if dim > N
143143
return static(1):static(1)
144144
else
145145
return getfield(axes(A), Int(dim))
146146
end
147147
end
148+
@inline _axes(::LinearAlgebra.AdjOrTrans{T,V}, ::One) where {T,V<:AbstractVector} = One():One()
149+
@inline axes(A::AbstractArray, dim::Integer) = _axes(A, dim, False())
150+
@inline axes(A::AbstractArray{T,N}, ::StaticInt{M}) where {T,N,M} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
151+
@inline _axes(::Any, ::Any, ::True) = One():One()
152+
@inline _axes(A::AbstractArray, dim, ::False) = _axes(A, dim)
148153

149-
axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
150-
axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
151-
axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
154+
155+
@inline _axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
156+
@inline _axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
157+
@inline _axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
152158

153159
"""
154160
axes(A)

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,12 @@ end
835835
@test axes(Base.Slice(StaticInt(2):4)) === (Base.IdentityUnitRange(StaticInt(2):4),)
836836
@test Base.axes1(ArrayInterface.indices(ones(2,2))) === StaticInt(1):4
837837
@test Base.axes1(Base.Slice(StaticInt(2):4)) === Base.IdentityUnitRange(StaticInt(2):4)
838+
839+
x = vec(A23); y = vec(A32);
840+
@test ArrayInterface.indices((x',y'),StaticInt(1)) === Base.Slice(StaticInt(1):StaticInt(1))
841+
@test ArrayInterface.axes(x',StaticInt(1)) === StaticInt(1):StaticInt(1)
842+
@test ArrayInterface.indices((x,y),StaticInt(2)) === Base.Slice(StaticInt(1):StaticInt(1))
843+
@test ArrayInterface.axes(x,StaticInt(2)) === StaticInt(1):StaticInt(1)
838844
end
839845

840846
@testset "insert/deleteat" begin

0 commit comments

Comments
 (0)