Skip to content

Commit d497de8

Browse files
committed
coax inference
1 parent 6148a9f commit d497de8

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/ArrayInterface.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,27 @@ known_length(::Type{<:Number}) = 1
9494
known_length(::Type{<:AbstractCartesianIndex{N}}) where {N} = N
9595
known_length(::Type{T}) where {T} = _maybe_known_length(Base.IteratorSize(T), T)
9696

97-
@inline _prod_or_nothing(x, ::Tuple{}) = x
98-
@inline _prod_or_nothing(_, ::Tuple{Nothing,Vararg}) = nothing
99-
@inline _prod_or_nothing(x, y::Tuple{I,Vararg}) where {I} = _prod_or_nothing(x*first(y), Base.tail(y))
97+
@generated function _prod_or_nothing(x::Tuple)
98+
p = 1
99+
for i in eachindex(x.parameters)
100+
x.parameters[i] === Nothing && return nothing
101+
p *= x.parameters[i].parameters[1]
102+
end
103+
StaticInt(p)
104+
end
105+
106+
function _maybe_known_length(::Base.HasShape, ::Type{T}) where {T}
107+
t = map(_static_or_nothing, known_size(T))
108+
_int_or_nothing(_prod_or_nothing(t))
109+
end
100110

101-
_maybe_known_length(::Base.HasShape, ::Type{T}) where {T} = _prod_or_nothing(1, known_size(T))
102111
_maybe_known_length(::Base.IteratorSize, ::Type) = nothing
112+
_static_or_nothing(::Nothing) = nothing
113+
@inline _static_or_nothing(x::Int) = StaticInt{x}()
114+
_int_or_nothing(::StaticInt{N}) where {N} = N
115+
_int_or_nothing(::Nothing) = nothing
103116
function known_length(::Type{<:Iterators.Flatten{I}}) where {I}
104-
_prod_or_nothing(1, (known_length(I),known_length(eltype(I))))
117+
_int_or_nothing(_prod_or_nothing((_static_or_nothing(known_length(I)),_static_or_nothing(known_length(eltype(I))))))
105118
end
106119

107120
"""

test/runtests.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,7 @@ end
518518

519519
@test @inferred(ArrayInterface.size(irev)) === (StaticInt(2), StaticInt(3), StaticInt(4))
520520
@test @inferred(ArrayInterface.size(iprod)) === (StaticInt(2), StaticInt(3), StaticInt(4))
521-
if VERSION >= v"1.7"
522-
@test @inferred(ArrayInterface.size(iflat)) === (static(72),)
523-
else
524-
@test_skip @inferred(ArrayInterface.size(iflat)) === (static(72),)
525-
end
521+
@test @inferred(ArrayInterface.size(iflat)) === (static(72),)
526522
@test @inferred(ArrayInterface.size(igen)) === (StaticInt(2), StaticInt(3), StaticInt(4))
527523
@test @inferred(ArrayInterface.size(iacc)) === (StaticInt(2), StaticInt(3), StaticInt(4))
528524
@test @inferred(ArrayInterface.size(ienum)) === (StaticInt(2), StaticInt(3), StaticInt(4))

0 commit comments

Comments
 (0)