Skip to content

Commit b8fbb3b

Browse files
authored
Cleanup/fix range and dimension related traits/methods (#196)
axes_types was erroneously assuming that the the parent axis always had an offfset of 1 known_first and known_last work for CartesianIndices now defined contiguous_axis, stride_rank, and offset1 for StrideIndex has_dimnames returns StaticBool now. Not breaking because no downstream dependencies yet.
1 parent 561848d commit b8fbb3b

File tree

11 files changed

+74
-43
lines changed

11 files changed

+74
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.27"
3+
version = "3.1.28"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/ArrayInterface.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr
1414

1515

1616
## utilites for internal use only ##
17-
_int_or_static_int(::Nothing) = Int
18-
_int_or_static_int(x::Int) = StaticInt{x}
19-
2017
@static if VERSION >= v"1.7.0-DEV.421"
2118
using Base: @aggressive_constprop
2219
else

src/array_index.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
194194
offsets::O
195195

196196
function StrideIndex{N,R,C}(s::S, o::O) where {N,R,C,S,O}
197-
return new{N,R::NTuple{N,Int},C::Int,S,O}(s, o)
197+
return new{N,R::NTuple{N,Int},C,S,O}(s, o)
198198
end
199199
function StrideIndex{N,R,C}(a::A) where {N,R,C,A}
200200
return StrideIndex{N,R,C}(strides(a), offsets(a))

src/axes.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11

2+
_static_range_type(::Any, ::Any) = OptionallyStaticUnitRange{Int,Int}
3+
_static_range_type(start::Int, ::Nothing) = OptionallyStaticUnitRange{StaticInt{start},Int}
4+
function _static_range_type(start::Int, size::Int)
5+
OptionallyStaticUnitRange{StaticInt{start},StaticInt{(size - 1) + start}}
6+
end
7+
28
"""
39
axes_types(::Type{T}) -> Type{Tuple{Vararg{AbstractUnitRange{Int}}}}
410
axes_types(::Type{T}, dim) -> Type{AbstractUnitRange{Int}}
@@ -54,11 +60,9 @@ end
5460
@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
5561
return eachop_tuple(_sub_axis_type, to_parent_dims(T), T)
5662
end
63+
5764
@inline function _sub_axis_type(::Type{A}, dim::StaticInt) where {T,N,P,I,A<:SubArray{T,N,P,I}}
58-
return OptionallyStaticUnitRange{
59-
_int_or_static_int(known_first(axes_types(P, dim))),
60-
_int_or_static_int(known_length(_get_tuple(I, dim)))
61-
}
65+
_static_range_type(known_first(axes_types(P, dim)),known_length(_get_tuple(I, dim)))
6266
end
6367

6468
function axes_types(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}

src/dimensions.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ from_parent_dims(x) = from_parent_dims(typeof(x))
4646
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
4747
from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),)
4848
from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One())
49-
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
50-
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,I<:Tuple}
49+
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I)
50+
@generated function _from_sub_dims(::Type{I}) where {I<:Tuple}
5151
out = Expr(:tuple)
5252
dim_i = 1
53-
for i in 1:ndims(A)
53+
for i in 1:length(I.parameters)
5454
p = I.parameters[i]
5555
if p <: Integer
5656
push!(out.args, :(StaticInt(0)))
@@ -103,8 +103,8 @@ to_parent_dims(x) = to_parent_dims(typeof(x))
103103
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
104104
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
105105
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = static(Val(I))
106-
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
107-
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
106+
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
107+
@generated function _to_sub_dims(::Type{I}) where {I<:Tuple}
108108
out = Expr(:tuple)
109109
n = 1
110110
for p in I.parameters
@@ -150,16 +150,12 @@ end
150150
"""
151151
has_dimnames(::Type{T}) -> Bool
152152
153-
Returns `true` if `x` has names for each dimension.
153+
Returns `static(true)` if `x` has on or more named dimensions.
154154
"""
155-
@inline has_dimnames(x) = has_dimnames(typeof(x))
156-
function has_dimnames(::Type{T}) where {T}
157-
if parent_type(T) <: T
158-
return false
159-
else
160-
return has_dimnames(parent_type(T))
161-
end
162-
end
155+
has_dimnames(x) = has_dimnames(typeof(x))
156+
@inline has_dimnames(::Type{T}) where {T} = _has_dimnames(dimnames(T))
157+
_has_dimnames(::Tuple{Vararg{StaticSymbol{:_}}}) = static(false)
158+
_has_dimnames(::Tuple) = static(true)
163159

164160
# this takes the place of dimension names that aren't defined
165161
const SUnderscore = StaticSymbol(:_)

src/ranges.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
_cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i)
3+
_cartesian_index(::Any) = nothing
4+
25
"""
36
known_first(::Type{T}) -> Union{Int,Nothing}
47
@@ -22,6 +25,10 @@ function known_first(::Type{T}) where {T}
2225
end
2326
end
2427
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
28+
function known_first(::Type{T}) where {N,R,T<:CartesianIndices{N,R}}
29+
_cartesian_index(ntuple(i -> known_first(R.parameters[i]), Val(N)))
30+
end
31+
2532

2633
"""
2734
known_last(::Type{T}) -> Union{Int,Nothing}
@@ -46,6 +53,9 @@ function known_last(::Type{T}) where {T}
4653
return known_last(parent_type(T))
4754
end
4855
end
56+
function known_last(::Type{T}) where {N,R,T<:CartesianIndices{N,R}}
57+
_cartesian_index(ntuple(i -> known_last(R.parameters[i]), Val(N)))
58+
end
4959

5060
"""
5161
known_step(::Type{T}) -> Union{Int,Nothing}

src/stridelayout.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ function _offsets(x::X, dim::StaticInt{D}) where {X,D}
6464
return static(start)
6565
end
6666
end
67+
# we can't generate an axis for `StrideIndex` so this is performed manually here
68+
@inline offsets(x::StrideIndex, dim::Int) = getfield(offsets(x), dim)
69+
@inline offsets(x::StrideIndex, ::StaticInt{dim}) where {dim} = getfield(offsets(x), dim)
6770

6871
"""
6972
known_offset1(::Type{T}) -> Union{Int,Nothing}
@@ -105,6 +108,8 @@ If no axis is contiguous, it returns a `StaticInt{-1}`.
105108
If unknown, it returns `nothing`.
106109
"""
107110
contiguous_axis(x) = contiguous_axis(typeof(x))
111+
contiguous_axis(::Type{<:StrideIndex{N,R,C}}) where {N,R,C} = static(C)
112+
contiguous_axis(::Type{<:StrideIndex{N,R,nothing}}) where {N,R} = nothing
108113
function contiguous_axis(::Type{T}) where {T}
109114
if parent_type(T) <: T
110115
return nothing
@@ -197,6 +202,7 @@ function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N}
197202
return sp
198203
end
199204

205+
stride_rank(::Type{<:StrideIndex{N,R}}) where {N,R} = static(R)
200206
stride_rank(x) = stride_rank(typeof(x))
201207
function stride_rank(::Type{T}) where {T}
202208
if parent_type(T) <: T

test/array_index.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
A = zeros(3, 4, 5);
3+
A[:] = 1:60
4+
Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])';
5+
6+
ap_index = ArrayInterface.StrideIndex(Ap)
7+
for x_i in axes(Ap, 1)
8+
for y_i in axes(Ap, 2)
9+
@test ap_index[x_i, y_i] == ap_index[x_i, y_i]
10+
end
11+
end
12+
@test @inferred(ArrayInterface.known_offsets(ap_index)) === ArrayInterface.known_offsets(Ap)
13+
@test @inferred(ArrayInterface.known_offset1(ap_index)) === ArrayInterface.known_offset1(Ap)
14+
@test @inferred(ArrayInterface.offsets(ap_index, 1)) === ArrayInterface.offset1(Ap)
15+
@test @inferred(ArrayInterface.offsets(ap_index, static(1))) === ArrayInterface.offset1(Ap)
16+
@test @inferred(ArrayInterface.known_strides(ap_index)) === ArrayInterface.known_strides(Ap)
17+
@test @inferred(ArrayInterface.contiguous_axis(ap_index)) == 1
18+
@test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing
19+
@test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3)
20+

test/dimensions.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractA
1010
NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p)
1111
end
1212
ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P
13-
ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true
1413
ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = static(Val(L))
1514
function ArrayInterface.dimnames(::Type{T}, dim) where {L,T<:NamedDimsWrapper{L}}
1615
if ndims(T) < dim
@@ -19,7 +18,6 @@ function ArrayInterface.dimnames(::Type{T}, dim) where {L,T<:NamedDimsWrapper{L}
1918
return static(L[dim])
2019
end
2120
end
22-
ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true
2321
Base.parent(x::NamedDimsWrapper) = x.parent
2422

2523
@testset "dimension permutations" begin
@@ -90,18 +88,17 @@ end
9088
@test_throws ErrorException ArrayInterface.order_named_inds(n2, (x=30, y=20, z=40))
9189
end
9290

93-
val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x))
9491

9592
@testset "dimnames" begin
9693
d = (static(:x), static(:y))
9794
x = NamedDimsWrapper{d}(ones(2,2));
9895
y = NamedDimsWrapper{(:x,)}(ones(2));
9996
dnums = ntuple(+, length(d))
100-
@test @inferred(val_has_dimnames(x)) === Val(true)
101-
@test @inferred(ArrayInterface.has_dimnames(ones(2,2))) === false
102-
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) === false
103-
@test @inferred(val_has_dimnames(typeof(x))) === Val(true)
104-
@test @inferred(val_has_dimnames(typeof(view(x, :, 1, :)))) === Val(true)
97+
@test @inferred(ArrayInterface.has_dimnames(x)) == true
98+
@test @inferred(ArrayInterface.has_dimnames(ones(2,2))) == false
99+
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) == false
100+
@test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true
101+
@test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true
105102
@test @inferred(dimnames(x)) === d
106103
@test @inferred(dimnames(parent(x))) === (static(:_), static(:_))
107104
@test @inferred(dimnames(x')) === reverse(d)

test/ranges.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@
5757
@test isnothing(@inferred(ArrayInterface.known_last(typeof(1:4))))
5858
@test isone(@inferred(ArrayInterface.known_last(typeof(StaticInt(-1):StaticInt(2):StaticInt(1)))))
5959

60+
# CartesianIndices
61+
CI = CartesianIndices((2, 2))
62+
@test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1)
63+
@test @inferred(ArrayInterface.known_last(typeof(CI))) == nothing
64+
65+
CI = CartesianIndices((static(1):static(2), static(1):static(2)))
66+
@test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1)
67+
@test @inferred(ArrayInterface.known_last(typeof(CI))) == CartesianIndex(2, 2)
68+
6069
@test isnothing(@inferred(ArrayInterface.known_step(typeof(1:0.2:4))))
6170
@test isone(@inferred(ArrayInterface.known_step(1:4)))
6271
@test isone(@inferred(ArrayInterface.known_step(typeof(1:4))))

0 commit comments

Comments
 (0)