diff --git a/src/comparison.jl b/src/comparison.jl index 8d1eba87..a5690f6e 100644 --- a/src/comparison.jl +++ b/src/comparison.jl @@ -459,11 +459,6 @@ struct ExponentsIterator{M,D<:Union{Nothing,Int},O} ), ) end - if length(object) == 0 && isnothing(maxdegree) - # Otherwise, it will incorrectly think that the iterator is infinite - # while it actually has zero elements - maxdegree = mindegree - end return new{M,typeof(maxdegree),typeof(object)}( object, mindegree, @@ -474,24 +469,49 @@ struct ExponentsIterator{M,D<:Union{Nothing,Int},O} end Base.eltype(::Type{ExponentsIterator{M,D,O}}) where {M,D,O} = O +# `IteratorSize` returns something different depending on whether it is called +# in an instance or on the type. `Iterators.Cycle` has the same behavior, +# see https://github.com/JuliaLang/julia/pull/54187 function Base.IteratorSize(::Type{<:ExponentsIterator{M,Nothing}}) where {M} + # It could be `HasLength` is `IsInfinite` depending on whether `it.object` + # is empty so the size is unknown when we only have access to the type of + # the iterator. The method below gives the correct size when we have access + # to the instance. + return Base.SizeUnknown() +end +function Base.IteratorSize(it::ExponentsIterator{M,Nothing}) where {M} + if isempty(it.object) + return Base.HasLength() + end return Base.IsInfinite() end function Base.IteratorSize(::Type{<:ExponentsIterator{M,Int}}) where {M} return Base.HasLength() end -function Base.length(it::ExponentsIterator{M,Int}) where {M} - if it.maxdegree < it.mindegree +function _length(it::ExponentsIterator, maxdegree) + if maxdegree < it.mindegree return 0 end - len = binomial(nvariables(it) + it.maxdegree, nvariables(it)) + len = binomial(nvariables(it) + maxdegree, nvariables(it)) if it.mindegree > 0 len -= binomial(nvariables(it) + it.mindegree - 1, nvariables(it)) end return len end +function Base.length(it::ExponentsIterator{M,Int}) where {M} + return _length(it, it.maxdegree) +end + +function Base.length(it::ExponentsIterator{M,Nothing}) where {M} + if isempty(it.object) + return _length(it, it.mindegree) + else + error("The iterator is infinity because `maxdegree` is `nothing`.") + end +end + nvariables(it::ExponentsIterator) = length(it.object) _last_lex_index(n, ::Type{LexOrder}) = n diff --git a/test/comparison.jl b/test/comparison.jl index bab42246..439a4a35 100644 --- a/test/comparison.jl +++ b/test/comparison.jl @@ -26,6 +26,11 @@ function test_errors() "Ordering `$M` is not a valid ordering, use `Graded{$M}` instead.", ) @test_throws err ExponentsIterator{M}([0], maxdegree = 2) + exps = ExponentsIterator{LexOrder}([0]) + err = ErrorException( + "The iterator is infinity because `maxdegree` is `nothing`.", + ) + @test_throws err length(exps) end function test_exponents_iterator()