Skip to content

Commit 0bcc1b5

Browse files
authored
Support static info on iterators (#236)
* Support static info on iterators * New support for types from `Base.Iterators` * Check for `Base.IteratorSize` before assuming we can rely on axes when finding the size. * Update internal calculation of range length to match base * Use `Iterators.Pairs` for v1.6 support * Test size error and Iterators.product * Use fall backs to `static_length` and `known_length` Use length methods for all non `HasShape` iterators and error or evaluate there
1 parent ca6cfd1 commit 0bcc1b5

File tree

6 files changed

+121
-54
lines changed

6 files changed

+121
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "4"
3+
version = "4.0.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/ArrayInterface.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,13 @@ known_length(x) = known_length(typeof(x))
9090
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
9191
known_length(::Type{T}) where {T<:Slice} = known_length(parent_type(T))
9292
known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = N
93-
known_length(::Type{T}) where {Itr,T<:Base.Generator{Itr}} = known_length(Itr)
9493
known_length(::Type{<:Number}) = 1
9594
known_length(::Type{<:AbstractCartesianIndex{N}}) where {N} = N
96-
function known_length(::Type{T}) where {T}
97-
if parent_type(T) <: T
98-
return missing
99-
else
100-
return prod(known_size(T))
101-
end
95+
known_length(::Type{T}) where {T} = _maybe_known_length(Base.IteratorSize(T), T)
96+
_maybe_known_length(::Base.HasShape, ::Type{T}) where {T} = prod(known_size(T))
97+
_maybe_known_length(::Base.IteratorSize, ::Type) = missing
98+
function known_length(::Type{<:Iterators.Flatten{I}}) where {I}
99+
known_length(I) * known_length(eltype(I))
102100
end
103101

104102
"""

src/dimensions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@ function throw_dim_error(@nospecialize(x), @nospecialize(dim))
33
throw(DimensionMismatch("$x does not have dimension corresponding to $dim"))
44
end
55

6+
@propagate_inbounds function _promote_shape(a::Tuple{A,Vararg{Any}}, b::Tuple{B,Vararg{Any}}) where {A,B}
7+
(_try_static(getfield(a, 1), getfield(b, 1)), _promote_shape(tail(a), tail(b))...)
8+
end
9+
_promote_shape(::Tuple{}, ::Tuple{}) = ()
10+
@propagate_inbounds function _promote_shape(::Tuple{}, b::Tuple{B}) where {B}
11+
(_try_static(static(1), getfield(b, 1)),)
12+
end
13+
@propagate_inbounds function _promote_shape(a::Tuple{A}, ::Tuple{}) where {A}
14+
(_try_static(static(1), getfield(a, 1)),)
15+
end
16+
@propagate_inbounds function Base.promote_shape(a::Tuple{Vararg{CanonicalInt}}, b::Tuple{Vararg{CanonicalInt}})
17+
_promote_shape(a, b)
18+
end
19+
620
#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10)))
721
# 0.045 ns (0 allocations: 0 bytes)
822
#ArrayInterface.True()

src/ranges.jl

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -304,66 +304,48 @@ end
304304
val::Int
305305
end
306306

307+
@noinline unequal_error(x,y) = @assert false "Unequal Indices: x == $x != $y == y"
308+
@inline check_equal(x, y) = x == y || unequal_error(x,y)
309+
_try_static(::Missing, ::Missing) = missing
310+
_try_static(x::Int, ::Missing) = x
311+
_try_static(::Missing, y::Int) = y
307312
@inline _try_static(::StaticInt{N}, ::StaticInt{N}) where {N} = StaticInt{N}()
308313
@inline function _try_static(::StaticInt{M}, ::StaticInt{N}) where {M,N}
309314
@assert false "Unequal Indices: StaticInt{$M}() != StaticInt{$N}()"
310315
end
311-
@noinline unequal_error(x,y) = @assert false "Unequal Indices: x == $x != $y == y"
312-
@inline function check_equal(x, y)
313-
x == y || unequal_error(x,y)
314-
end
315-
@propagate_inbounds function _try_static(::StaticInt{N}, x) where {N}
316-
@boundscheck check_equal(StaticInt{N}(), x)
317-
return StaticInt{N}()
318-
end
319-
@propagate_inbounds function _try_static(x, ::StaticInt{N}) where {N}
320-
@boundscheck check_equal(x, StaticInt{N}())
321-
return StaticInt{N}()
322-
end
316+
@propagate_inbounds _try_static(::StaticInt{N}, x) where {N} = static(_try_static(N, x))
317+
@propagate_inbounds _try_static(x, ::StaticInt{N}) where {N} = static(_try_static(N, x))
323318
@propagate_inbounds function _try_static(x, y)
324319
@boundscheck check_equal(x, y)
325320
return x
326321
end
327322

328323
## length
329-
@inline function known_length(::Type{T}) where {T<:OptionallyStaticUnitRange}
330-
return _range_length(known_first(T), known_last(T))
331-
end
332-
333-
@inline function known_length(::Type{T}) where {T<:OptionallyStaticStepRange}
334-
_range_length(known_first(T), known_step(T), known_last(T))
335-
end
336-
337324
Base.lastindex(x::OptionallyStaticRange) = length(x)
338-
Base.length(r::OptionallyStaticUnitRange) = _range_length(static_first(r), static_last(r))
339-
@inline function Base.length(r::OptionallyStaticStepRange)
325+
@inline function Base.length(r::OptionallyStaticUnitRange)
340326
if isempty(r)
341327
return 0
342328
else
343-
return _range_length(static_first(r), static_step(r), static_last(r))
329+
return last(r) - first(r) + 1
344330
end
345331
end
346-
_range_length(start, stop) = missing
347-
function _range_length(start::CanonicalInt, stop::CanonicalInt)
348-
if start > stop
349-
return 0
350-
else
351-
return Int((stop - start) + 1)
352-
end
353-
end
354-
_range_length(start::CanonicalInt, ::One, stop::CanonicalInt) = _range_length(start, stop)
355-
@inline function _range_length(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt)
356-
if step > 1
357-
return Base.checked_add(Int(div(unsigned(stop - start), step)), 1)
358-
elseif step < -1
359-
return Base.checked_add(Int(div(unsigned(start - stop), -step)), 1)
360-
elseif step > 0
361-
return Base.checked_add(Int(div(Base.checked_sub(stop, start), step)), 1)
332+
Base.length(r::OptionallyStaticStepRange) = _range_length(first(r), step(r), last(r))
333+
_range_length(start, s, stop) = missing
334+
@inline function _range_length(start::Int, s::Int, stop::Int)
335+
if s > 0
336+
if stop < start # isempty
337+
return 0
338+
else
339+
return Int(div(stop - start, s)) + 1
340+
end
362341
else
363-
return Base.checked_add(Int(div(Base.checked_sub(start, stop), -step)), 1)
342+
if stop > start # isempty
343+
return 0
344+
else
345+
return Int(div(start - stop, -s)) + 1
346+
end
364347
end
365348
end
366-
_range_length(start, step, stop) = missing
367349

368350
Base.AbstractUnitRange{Int}(r::OptionallyStaticUnitRange) = r
369351
function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T}

src/size.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ julia> ArrayInterface.size(A)
1616
(static(3), static(4))
1717
```
1818
"""
19-
@inline size(A) = map(static_length, axes(A))
19+
size(a::A) where {A} = _maybe_size(Base.IteratorSize(A), a)
20+
_maybe_size(::Base.HasShape{N}, a::A) where {N,A} = map(static_length, axes(a))
21+
_maybe_size(::Base.HasLength, a::A) where {A} = (static_length(a),)
2022
size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
2123
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))
2224
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@@ -40,6 +42,14 @@ function size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A}
4042
end
4143
size(A::ReshapedArray) = Base.size(A)
4244
size(A::AbstractRange) = (static_length(A),)
45+
size(x::Base.Generator) = size(getfield(x, :iter))
46+
size(x::Iterators.Reverse) = size(getfield(x, :itr))
47+
size(x::Iterators.Enumerate) = size(getfield(x, :itr))
48+
size(x::Iterators.Accumulate) = size(getfield(x, :itr))
49+
size(x::Iterators.Pairs) = size(getfield(x, :itr))
50+
@inline function size(x::Iterators.ProductIterator)
51+
eachop(_sub_size, nstatic(Val(ndims(x))), getfield(x, :iterators))
52+
end
4353

4454
size(a, dim) = size(a, to_dims(a, dim))
4555
size(a::Array, dim::Integer) = Base.arraysize(a, convert(Int, dim))
@@ -63,6 +73,7 @@ function size(A::SubArray, dim::Integer)
6373
return static_length(A.indices[pdim])
6474
end
6575
end
76+
size(x::Iterators.Zip) = Static.reduce_tup(promote_shape, map(size, getfield(x, :is)))
6677

6778
"""
6879
known_size(::Type{T}) -> Tuple
@@ -73,7 +84,34 @@ compile time. If a dimension does not have a known size along a dimension then `
7384
returned in its position.
7485
"""
7586
known_size(x) = known_size(typeof(x))
76-
known_size(::Type{T}) where {T} = eachop(_known_size, nstatic(Val(ndims(T))), axes_types(T))
87+
function known_size(::Type{T}) where {T<:AbstractRange}
88+
(_range_length(known_first(T), known_step(T), known_last(T)),)
89+
end
90+
known_size(::Type{<:Base.Generator{I}}) where {I} = known_size(I)
91+
known_size(::Type{<:Iterators.Reverse{I}}) where {I} = known_size(I)
92+
known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
93+
known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I)
94+
known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I)
95+
@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T}
96+
eachop(_known_size, nstatic(Val(known_length(T))), T)
97+
end
98+
99+
# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,
100+
# but we assume as much b/c otherwise it will error while iterating. So we promote to the
101+
# known size if matching a `Missing` and `Int` size.
102+
# 2. `promote_shape(::Tuple{Vararg{CanonicalInt}}, ::Tuple{Vararg{CanonicalInt}})` promotes
103+
# trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to
104+
# `Missing` and `Int` types, so we do one last pass to ensure everything is dynamic
105+
@inline function known_size(::Type{<:Iterators.Zip{T}}) where {T}
106+
dynamic(reduce_tup(_promote_shape, eachop(_unzip_size, nstatic(Val(known_length(T))), T)))
107+
end
108+
_unzip_size(::Type{T}, n::StaticInt{N}) where {T,N} = known_size(field_type(T, n))
109+
110+
known_size(::Type{T}) where {T} = _maybe_known_size(Base.IteratorSize(T), T)
111+
function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
112+
eachop(_known_size, nstatic(Val(N)), axes_types(T))
113+
end
114+
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
77115
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, dim))
78116
@inline known_size(x, dim) = known_size(typeof(x), dim)
79117
@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim))
@@ -84,3 +122,4 @@ _known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, di
84122
return known_size(T)[dim]
85123
end
86124
end
125+

test/runtests.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,11 @@ end
479479
A = zeros(3, 4, 5);
480480
A[:] = 1:60
481481
Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])';
482-
S = @SArray zeros(2,3,4); Sp = @view(PermutedDimsArray(S,(3,1,2))[2:3,1:2,:]);
483-
M = @MArray zeros(2,3,4); Mp = @view(PermutedDimsArray(M,(3,1,2))[:,2,:])';
482+
S = @SArray zeros(2,3,4)
483+
A_trailingdim = zeros(2,3,4,1)
484+
Sp = @view(PermutedDimsArray(S,(3,1,2))[2:3,1:2,:]);
485+
M = @MArray zeros(2,3,4)
486+
Mp = @view(PermutedDimsArray(M,(3,1,2))[:,2,:])';
484487
Sp2 = @view(PermutedDimsArray(S,(3,2,1))[2:3,:,:]);
485488
Mp2 = @view(PermutedDimsArray(M,(3,1,2))[2:3,:,2])';
486489
D = @view(A[:,2:2:4,:]);
@@ -490,6 +493,16 @@ end
490493
A2 = zeros(4, 3, 5)
491494
A2r = reinterpret(ComplexF64, A2)
492495

496+
irev = Iterators.reverse(S)
497+
igen = Iterators.map(identity, S)
498+
iacc = Iterators.accumulate(+, S)
499+
iprod = Iterators.product(axes(S)...)
500+
iflat = Iterators.flatten(iprod)
501+
ienum = enumerate(S)
502+
ipairs = pairs(S)
503+
izip = zip(S,S)
504+
505+
493506
sv5 = @SVector(zeros(5)); v5 = Vector{Float64}(undef, 5);
494507
@test @inferred(ArrayInterface.size(sv5)) === (StaticInt(5),)
495508
@test @inferred(ArrayInterface.size(v5)) === (5,)
@@ -503,6 +516,16 @@ end
503516
@test @inferred(ArrayInterface.size(A2)) === (4,3,5)
504517
@test @inferred(ArrayInterface.size(A2r)) === (2,3,5)
505518

519+
@test @inferred(ArrayInterface.size(irev)) === (StaticInt(2), StaticInt(3), StaticInt(4))
520+
@test @inferred(ArrayInterface.size(iprod)) === (StaticInt(2), StaticInt(3), StaticInt(4))
521+
@test @inferred(ArrayInterface.size(iflat)) === (static(72),)
522+
@test @inferred(ArrayInterface.size(igen)) === (StaticInt(2), StaticInt(3), StaticInt(4))
523+
@test @inferred(ArrayInterface.size(iacc)) === (StaticInt(2), StaticInt(3), StaticInt(4))
524+
@test @inferred(ArrayInterface.size(ienum)) === (StaticInt(2), StaticInt(3), StaticInt(4))
525+
@test @inferred(ArrayInterface.size(ipairs)) === (StaticInt(2), StaticInt(3), StaticInt(4))
526+
@test @inferred(ArrayInterface.size(izip)) === (StaticInt(2), StaticInt(3), StaticInt(4))
527+
@test @inferred(ArrayInterface.size(zip(S, A_trailingdim))) === (StaticInt(2), StaticInt(3), StaticInt(4), static(1))
528+
@test @inferred(ArrayInterface.size(zip(A_trailingdim, S))) === (StaticInt(2), StaticInt(3), StaticInt(4), static(1))
506529
@test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4))
507530
@test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3))
508531
@test @inferred(ArrayInterface.size(Sp2)) === (2, StaticInt(3), StaticInt(2))
@@ -536,6 +559,18 @@ end
536559
@test @inferred(ArrayInterface.known_size(A2)) === (missing, missing, missing)
537560
@test @inferred(ArrayInterface.known_size(A2r)) === (missing, missing, missing)
538561

562+
@test @inferred(ArrayInterface.known_size(irev)) === (2, 3, 4)
563+
@test @inferred(ArrayInterface.known_size(igen)) === (2, 3, 4)
564+
@test @inferred(ArrayInterface.known_size(iprod)) === (2, 3, 4)
565+
@test @inferred(ArrayInterface.known_size(iflat)) === (72,)
566+
@test @inferred(ArrayInterface.known_size(iacc)) === (2, 3, 4)
567+
@test @inferred(ArrayInterface.known_size(ienum)) === (2, 3, 4)
568+
@test @inferred(ArrayInterface.known_size(izip)) === (2, 3, 4)
569+
@test @inferred(ArrayInterface.known_size(ipairs)) === (2, 3, 4)
570+
@test @inferred(ArrayInterface.known_size(zip(S, A_trailingdim))) === (2, 3, 4, 1)
571+
@test @inferred(ArrayInterface.known_size(zip(A_trailingdim, S))) === (2, 3, 4, 1)
572+
@test @inferred(ArrayInterface.known_length(Iterators.flatten(((x,y) for x in 0:1 for y in 'a':'c')))) === missing
573+
539574
@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
540575
@test @inferred(ArrayInterface.known_size(Wrapper(S))) === (2, 3, 4)
541576
@test @inferred(ArrayInterface.known_size(Sp)) === (missing, missing, 3)
@@ -544,7 +579,6 @@ end
544579
@test @inferred(ArrayInterface.known_size(Sp2, StaticInt(1))) === missing
545580
@test @inferred(ArrayInterface.known_size(Sp2, StaticInt(2))) === 3
546581
@test @inferred(ArrayInterface.known_size(Sp2, StaticInt(3))) === 2
547-
548582
@test @inferred(ArrayInterface.known_size(M)) === (2, 3, 4)
549583
@test @inferred(ArrayInterface.known_size(Mp)) === (3, 4)
550584
@test @inferred(ArrayInterface.known_size(Mp2)) === (2, missing)

0 commit comments

Comments
 (0)