Skip to content

Commit fa1e749

Browse files
committed
Layer the axes implementation
1 parent 5ee8108 commit fa1e749

File tree

2 files changed

+21
-25
lines changed

2 files changed

+21
-25
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,9 +1163,9 @@ function __init__()
11631163
function stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray}
11641164
return stride_rank(parent_type(A))
11651165
end
1166-
axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
1167-
axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
1168-
axes(A::OffsetArrays.OffsetArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
1166+
@inline axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
1167+
@inline _axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
1168+
@inline axes(A::OffsetArrays.OffsetArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
11691169
end
11701170
end
11711171

src/axes.jl

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -120,45 +120,41 @@ similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N1}}}, ::Type{Int},
120120
121121
Return a valid range that maps to each index along dimension `d` of `A`.
122122
"""
123-
axes(a, dim) = axes(a, to_dims(a, dim))
124-
axes(a, dims::Tuple{Vararg{Any,K}}) where {K} = (axes(a, first(dims)), axes(a, tail(dims))...)
125-
axes(a, dims::Tuple{T}) where {T} = (axes(a, first(dims)), )
126-
axes(a, ::Tuple{}) = ()
127-
function axes(a::A, dim::Integer) where {A}
123+
@inline axes(a, dim) = axes(a, to_dims(a, dim))
124+
@inline axes(a, dims::Tuple{Vararg{Any,K}}) where {K} = (axes(a, first(dims)), axes(a, tail(dims))...)
125+
@inline axes(a, dims::Tuple{T}) where {T} = (axes(a, first(dims)), )
126+
@inline axes(a, ::Tuple{}) = ()
127+
@inline function _axes(a::A, dim::Integer) where {A}
128128
if parent_type(A) <: A
129129
return Base.axes(a, Int(dim))
130130
else
131-
return axes(parent(a), to_parent_dims(A, dim))
131+
return _axes(parent(a), to_parent_dims(A, dim))
132132
end
133133
end
134-
function axes(A::CartesianIndices{N}, dim::Integer) where {N}
134+
@inline function _axes(A::CartesianIndices{N}, dim::Integer) where {N}
135135
if dim > N
136136
return static(1):static(1)
137137
else
138138
return getfield(axes(A), Int(dim))
139139
end
140140
end
141-
function axes(A::LinearIndices{N}, dim::Integer) where {N}
141+
@inline function _axes(A::LinearIndices{N}, dim::Integer) where {N}
142142
if dim > N
143143
return static(1):static(1)
144144
else
145145
return getfield(axes(A), Int(dim))
146146
end
147147
end
148-
axes(::LinearAlgebra.AdjOrTrans{T,V}, ::One) where {T,V<:AbstractVector} = One():One()
149-
axes(A::AbstractArray{T,N}, ::StaticInt{M}) where {T,N,M} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
150-
axes(A::CartesianIndices{N}, ::StaticInt{M}) where {N,M} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
151-
axes(A::LinearIndices{N}, ::StaticInt{M}) where {N,M} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
152-
_axes(::Any, ::Any, ::True) = One():One()
153-
_axes(A::AbstractArray, ::StaticInt{M}, ::False) where {M} = axes(A, M)
154-
155-
156-
axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
157-
axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
158-
axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
159-
axes(A::SubArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
160-
axes(A::ReinterpretArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
161-
axes(A::Base.ReshapedArray{T,N}, ::StaticInt{M}) where {T,M,N} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
148+
@inline _axes(::LinearAlgebra.AdjOrTrans{T,V}, ::One) where {T,V<:AbstractVector} = One():One()
149+
@inline axes(A::AbstractArray, dim::Integer) = _axes(A, dim, False())
150+
@inline axes(A::AbstractArray{T,N}, ::StaticInt{M}) where {T,N,M} = _axes(A, StaticInt{M}(), gt(StaticInt{M}(),StaticInt{N}()))
151+
@inline _axes(::Any, ::Any, ::True) = One():One()
152+
@inline _axes(A::AbstractArray, dim, ::False) = _axes(A, dim)
153+
154+
155+
@inline _axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
156+
@inline _axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
157+
@inline _axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
162158

163159
"""
164160
axes(A)

0 commit comments

Comments
 (0)