Skip to content

Commit 8d89c2f

Browse files
authored
Clean axes.jl file (#288)
* Clean axes.jl file Most of this is despecialization for `LazyAxis` b/c I've found certain methods will randomly be invalidated, depending on which build of Julia is used. There's also some improved handling of `IdentityUnitRange` here.
1 parent 09eeb2f commit 8d89c2f

File tree

4 files changed

+20
-44
lines changed

4 files changed

+20
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.7"
3+
version = "6.0.8"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/ArrayInterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ function Base.axes(A::AbstractArray2)
5151
!(parent_type(A) <: typeof(A)) && return ArrayInterface.axes(parent(A))
5252
throw(ArgumentError("Subtypes of `AbstractArray2` must define an axes method"))
5353
end
54-
Base.axes(A::AbstractArray2, dim::Union{Symbol,StaticSymbol}) = Base.axes(A, to_dims(A, dim))
54+
function Base.axes(A::AbstractArray2, dim::Union{Symbol,StaticSymbol})
55+
axes(A, to_dims(A, dim))
56+
end
5557

5658
function Base.strides(A::AbstractArray2)
5759
defines_strides(A) && return map(Int, ArrayInterface.strides(A))

src/axes.jl

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ axes_types(::Type{T}) where {T<:Array} = NTuple{ndims(T),OneTo{Int}}
2929
return axes_types(parent_type(T))
3030
end
3131
end
32-
axes_types(::Type{LinearIndices{N,R}}) where {N,R} = R
33-
axes_types(::Type{CartesianIndices{N,R}}) where {N,R} = R
32+
axes_types(::Type{<:LinearIndices{N,R}}) where {N,R} = R
33+
axes_types(::Type{<:CartesianIndices{N,R}}) where {N,R} = R
3434
function axes_types(::Type{T}) where {T<:VecAdjTrans}
3535
Tuple{SOneTo{1},axes_types(parent_type(T), static(1))}
3636
end
@@ -40,7 +40,9 @@ end
4040
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
4141
eachop_tuple(field_type, to_parent_dims(T), axes_types(parent_type(T)))
4242
end
43+
axes_types(T::Type{<:Base.IdentityUnitRange}) = Tuple{T}
4344
axes_types(::Type{<:Base.Slice{I}}) where {I} = Tuple{Base.IdentityUnitRange{I}}
45+
axes_types(::Type{<:Base.Slice{I}}) where {I<:Base.IdentityUnitRange} = Tuple{I}
4446
function axes_types(::Type{T}) where {T<:AbstractRange}
4547
if known_length(T) === nothing
4648
return Tuple{OneTo{Int}}
@@ -207,49 +209,30 @@ end
207209

208210
Base.keys(x::LazyAxis) = keys(parent(x))
209211

210-
Base.IndexStyle(::Type{T}) where {T<:LazyAxis} = IndexStyle(parent_type(T))
212+
Base.IndexStyle(::Type{<:LazyAxis}) = IndexStyle(parent_type(T))
211213

212-
can_change_size(::Type{LazyAxis{N,P}}) where {N,P} = can_change_size(P)
214+
ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent))
213215

214-
ArrayInterfaceCore.known_first(::Type{LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
215-
ArrayInterfaceCore.known_first(::Type{LazyAxis{:,P}}) where {P} = 1
216-
Base.firstindex(x::LazyAxis) = first(x)
216+
ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
217+
ArrayInterfaceCore.known_first(::Type{<:LazyAxis{:,P}}) where {P} = 1
217218
@inline function Base.first(x::LazyAxis{N})::Int where {N}
218219
if ArrayInterfaceCore.known_first(x) === nothing
219-
return Int(offsets(parent(x), static(N)))
220+
return Int(offsets(parent(x), StaticInt(N)))
220221
else
221222
return Int(known_first(x))
222223
end
223224
end
224-
@inline function Base.first(x::LazyAxis{:})::Int
225-
if known_first(x) === nothing
226-
return first(parent(x))
227-
else
228-
return known_first(x)
229-
end
230-
end
225+
@inline Base.first(x::LazyAxis{:})::Int = Int(offset1(getfield(x, :parent)))
231226
ArrayInterfaceCore.known_last(::Type{LazyAxis{N,P}}) where {N,P} = known_last(axes_types(P, static(N)))
232227
ArrayInterfaceCore.known_last(::Type{LazyAxis{:,P}}) where {P} = known_length(P)
233228
Base.last(x::LazyAxis) = _last(known_last(x), x)
234229
_last(::Nothing, x) = last(parent(x))
235230
_last(N::Int, x) = N
236231

237-
known_length(::Type{LazyAxis{N,P}}) where {N,P} = known_size(P, static(N))
238-
known_length(::Type{LazyAxis{:,P}}) where {P} = known_length(P)
239-
@inline function Base.length(x::LazyAxis{N})::Int where {N}
240-
if known_length(x) === nothing
241-
return size(getfield(x, :parent), static(N))
242-
else
243-
return known_length(x)
244-
end
245-
end
246-
@inline function Base.length(x::LazyAxis{:})::Int
247-
if known_length(x) === nothing
248-
return length(parent(x))
249-
else
250-
return known_length(x)
251-
end
252-
end
232+
known_length(::Type{<:LazyAxis{:,P}}) where {P} = known_length(P)
233+
known_length(::Type{<:LazyAxis{N,P}}) where {N,P} = known_size(P, static(N))
234+
@inline Base.length(x::LazyAxis{:}) = Base.length(getfield(x, :parent))
235+
@inline Base.length(x::LazyAxis{N}) where {N} = Base.size(getfield(x, :parent), N)
253236

254237
Base.axes(x::LazyAxis) = (Base.axes1(x),)
255238
Base.axes1(x::LazyAxis) = x
@@ -259,14 +242,6 @@ Base.axes(x::Slice{<:LazyAxis}) = (Base.axes1(x),)
259242
Base.axes1(x::Slice{<:LazyAxis}) = indices(parent(x.indices))
260243
Base.to_shape(x::LazyAxis) = length(x)
261244

262-
@inline function Base.checkindex(::Type{Bool}, x::LazyAxis, i::CanonicalInt)
263-
if known_first(x) === nothing || known_last(x) === nothing
264-
return checkindex(Bool, parent(x), i)
265-
else # everything is static so we don't have to retrieve the axis
266-
return (!(known_first(x) > i) || !(known_last(x) < i))
267-
end
268-
end
269-
270245
@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
271246
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
272247
return Int(i)
@@ -288,8 +263,7 @@ constructed or it is simply retrieved.
288263
@generated function lazy_axes(x::X) where {X}
289264
Expr(:block, Expr(:meta, :inline), Expr(:tuple, [:(LazyAxis{$dim}(x)) for dim in 1:ndims(X)]...))
290265
end
291-
lazy_axes(x::LinearIndices) = axes(x)
292-
lazy_axes(x::CartesianIndices) = axes(x)
266+
lazy_axes(x::Union{LinearIndices,CartesianIndices,AbstractRange}) = axes(x)
293267
@inline lazy_axes(x::MatAdjTrans) = reverse(lazy_axes(parent(x)))
294268
@inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x))))
295269
@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x))

test/axes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ m = Array{Float64}(undef, 4, 3)
66
@test @inferred(ArrayInterface.axes(m')) === (Base.OneTo(3),Base.OneTo(4))
77
@test ArrayInterface.axes(v', StaticInt(1)) === StaticInt(1):StaticInt(1)
88
@test ArrayInterface.axes(v, StaticInt(2)) === StaticInt(1):StaticInt(1)
9+
@test ArrayInterface.axes_types(view(CartesianIndices(map(Base.Slice, (0:3, 3:5))), 0, :), 1) <: Base.IdentityUnitRange
910

1011
@testset "LazyAxis" begin
1112
A = zeros(3,4,5);
@@ -88,4 +89,3 @@ if isdefined(Base, :ReshapedReinterpretArray)
8889
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
8990
end
9091
end
91-

0 commit comments

Comments
 (0)