Skip to content

Commit 8421c03

Browse files
N5N3KristofferC
authored andcommitted
Improve foldl's stability on nested Iterators (#45789)
* Make `Fix1(f, Int)` inference-stable * split `_xfadjoint` into `_xfadjoint_unwrap` and `_xfadjoint_wrap` * Improve `(c::ComposedFunction)(x...)`'s inferability * and fuse it in `Base._xfadjoint`. * define a `Typeof` operator that will partly work around internal type-system bugs Closes #45715 (cherry picked from commit d58289c)
1 parent 98efbdf commit 8421c03

File tree

4 files changed

+65
-17
lines changed

4 files changed

+65
-17
lines changed

base/operators.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,9 @@ julia> [1:5;] |> (x->x.^2) |> sum |> inv
910910
"""
911911
|>(x, f) = f(x)
912912

913+
_stable_typeof(x) = typeof(x)
914+
_stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType
915+
913916
"""
914917
f = Returns(value)
915918
@@ -936,7 +939,7 @@ julia> f.value
936939
struct Returns{V} <: Function
937940
value::V
938941
Returns{V}(value) where {V} = new{V}(value)
939-
Returns(value) = new{Core.Typeof(value)}(value)
942+
Returns(value) = new{_stable_typeof(value)}(value)
940943
end
941944

942945
(obj::Returns)(args...; kw...) = obj.value
@@ -1027,7 +1030,19 @@ struct ComposedFunction{O,I} <: Function
10271030
ComposedFunction(outer, inner) = new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner)
10281031
end
10291032

1030-
(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))
1033+
function (c::ComposedFunction)(x...; kw...)
1034+
fs = unwrap_composed(c)
1035+
call_composed(fs[1](x...; kw...), tail(fs)...)
1036+
end
1037+
unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...)
1038+
unwrap_composed(c) = (maybeconstructor(c),)
1039+
call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...))
1040+
call_composed(x, f) = f(x)
1041+
1042+
struct Constructor{F} <: Function end
1043+
(::Constructor{F})(args...; kw...) where {F} = (@inline; F(args...; kw...))
1044+
maybeconstructor(::Type{F}) where {F} = Constructor{F}()
1045+
maybeconstructor(f) = f
10311046

10321047
(f) = f
10331048
(f, g) = ComposedFunction(f, g)
@@ -1074,8 +1089,8 @@ struct Fix1{F,T} <: Function
10741089
f::F
10751090
x::T
10761091

1077-
Fix1(f::F, x::T) where {F,T} = new{F,T}(f, x)
1078-
Fix1(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
1092+
Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
1093+
Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
10791094
end
10801095

10811096
(f::Fix1)(y) = f.f(f.x, y)
@@ -1091,8 +1106,8 @@ struct Fix2{F,T} <: Function
10911106
f::F
10921107
x::T
10931108

1094-
Fix2(f::F, x::T) where {F,T} = new{F,T}(f, x)
1095-
Fix2(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x)
1109+
Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x)
1110+
Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x)
10961111
end
10971112

10981113
(f::Fix2)(y) = f.f(y, f.x)

base/reduce.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,25 @@ what is returned is `itr′` and
140140
141141
op′ = (xfₙ ∘ ... ∘ xf₂ ∘ xf₁)(op)
142142
"""
143-
_xfadjoint(op, itr) = (op, itr)
144-
_xfadjoint(op, itr::Generator) =
145-
if itr.f === identity
146-
_xfadjoint(op, itr.iter)
147-
else
148-
_xfadjoint(MappingRF(itr.f, op), itr.iter)
149-
end
150-
_xfadjoint(op, itr::Filter) =
151-
_xfadjoint(FilteringRF(itr.flt, op), itr.itr)
152-
_xfadjoint(op, itr::Flatten) =
153-
_xfadjoint(FlatteningRF(op), itr.it)
143+
function _xfadjoint(op, itr)
144+
itr′, wrap = _xfadjoint_unwrap(itr)
145+
wrap(op), itr′
146+
end
147+
148+
_xfadjoint_unwrap(itr) = itr, identity
149+
function _xfadjoint_unwrap(itr::Generator)
150+
itr′, wrap = _xfadjoint_unwrap(itr.iter)
151+
itr.f === identity && return itr′, wrap
152+
return itr′, wrap Fix1(MappingRF, itr.f)
153+
end
154+
function _xfadjoint_unwrap(itr::Filter)
155+
itr′, wrap = _xfadjoint_unwrap(itr.itr)
156+
return itr′, wrap Fix1(FilteringRF, itr.flt)
157+
end
158+
function _xfadjoint_unwrap(itr::Flatten)
159+
itr′, wrap = _xfadjoint_unwrap(itr.it)
160+
return itr′, wrap FlatteningRF
161+
end
154162

155163
"""
156164
mapfoldl(f, op, itr; [init])

test/operators.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,15 @@ Base.promote_rule(::Type{T19714}, ::Type{Int}) = T19714
175175

176176
end
177177

178+
@testset "Nested ComposedFunction's stability" begin
179+
f(x) = (1, 1, x...)
180+
g = (f (f f)) (f f f)
181+
@test (@inferred (gg)(1)) == ntuple(Returns(1), 25)
182+
@test (@inferred g(1)) == ntuple(Returns(1), 13)
183+
h = (-) (-) (-) (-) (-) (-) sum
184+
@test (@inferred h((1, 2, 3); init = 0.0)) == 6.0
185+
end
186+
178187
@testset "function negation" begin
179188
str = randstring(20)
180189
@test filter(!isuppercase, str) == replace(str, r"[A-Z]" => "")
@@ -302,6 +311,9 @@ end
302311
val = [1,2,3]
303312
@test Returns(val)(1) === val
304313
@test sprint(show, Returns(1.0)) == "Returns{Float64}(1.0)"
314+
315+
illtype = Vector{Core._typevar(:T, Union{}, Any)}
316+
@test Returns(illtype) == Returns{DataType}(illtype)
305317
end
306318

307319
@testset "<= (issue #46327)" begin

test/reduce.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,3 +677,16 @@ end
677677
@test mapreduce(+, +, oa, oa) == 2len
678678
end
679679
end
680+
681+
# issue #45748
682+
@testset "foldl's stability for nested Iterators" begin
683+
a = Iterators.flatten((1:3, 1:3))
684+
b = (2i for i in a if i > 0)
685+
c = Base.Generator(Float64, b)
686+
d = (sin(i) for i in c if i > 0)
687+
@test @inferred(sum(d)) == sum(collect(d))
688+
@test @inferred(extrema(d)) == extrema(collect(d))
689+
@test @inferred(maximum(c)) == maximum(collect(c))
690+
@test @inferred(prod(b)) == prod(collect(b))
691+
@test @inferred(minimum(a)) == minimum(collect(a))
692+
end

0 commit comments

Comments
 (0)