diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 83ce200a4e..3ff81b6fe3 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -550,6 +550,12 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F} return T(__default_init(Float16, op)) end +struct TracedCall{F} <: Function + f::F +end + +(fn::TracedCall)(args...) = @opcall call(fn.f, args...) + function overloaded_mapreduce( @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue() ) @@ -558,7 +564,7 @@ function overloaded_mapreduce( # unroll the mapreduce. if typeof(res) == typeof(A) @assert dims == Colon() "dims not supported for mapreduce currently." - return foldl(op, res; init) + return foldl(TracedCall(op), res; init) end return overloaded_mapreduce(identity, op, res; dims=:, init) end @@ -1503,6 +1509,9 @@ end struct BroadcastIterator{F} f::F + + BroadcastIterator{F}(f::F) where {F} = new{F}(f) + BroadcastIterator(f::F) where {F} = new{F}(f) end (fn::BroadcastIterator)(args...) = fn.f((args...,)) @@ -1523,14 +1532,13 @@ end unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs) -# TODO: once traced_call supports internal mutations, we can use traced_call here # TODO: we should overload this for Slices and use mapslices instead function unrolled_map(f::F, itr) where {F} y = Reactant.call_with_reactant(iterate, itr) y === nothing && return [] first, state = y - res_first = Reactant.call_with_reactant(f, first) + res_first = @opcall call(f, first) result = [res_first] while true @@ -1538,7 +1546,7 @@ function unrolled_map(f::F, itr) where {F} y === nothing && break val, state = y - res = Reactant.call_with_reactant(f, val) + res = @opcall call(f, val) push!(result, res) end diff --git a/test/basic.jl b/test/basic.jl index e8e03cab70..d1976f1d02 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1496,6 +1496,15 @@ end @test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y) end + + @testset "mapreduce vector" begin + x = [rand(Float32, 2, 3) for _ in 1:10] + x_ra = Reactant.to_rarray(x) + + @test @jit(mapreduce_vector(x_ra)) ≈ mapreduce_vector(x) + hlo = repr(@code_hlo optimize = false mapreduce_vector(x_ra)) + @test contains(hlo, "call") + end end @testset "Base.Generator" begin diff --git a/test/integration/fillarrays.jl b/test/integration/fillarrays.jl index f7d7359fe6..76ffad570a 100644 --- a/test/integration/fillarrays.jl +++ b/test/integration/fillarrays.jl @@ -25,5 +25,5 @@ end x = OneElement(3.4f0, (3, 4), (32, 32)) rx = Reactant.to_rarray(x) - @test @jit(fn(rx, rx)) ≈ fn(x, x) + @test @jit(fn(rx, rx)) ≈ fn(x, x) atol = 1e-3 rtol = 1e-3 end