From 7f15e652251f968bacbcd984e255fad546ef994c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Sep 2025 17:20:40 -0400 Subject: [PATCH 1/3] feat: tracing call + correct init value --- src/TracedRArray.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 83ce200a4e..1121291d15 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -550,6 +550,12 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F} return T(__default_init(Float16, op)) end +struct TracedCall{F} <: Function + f::F +end + +(fn::TracedCall)(args...) = @opcall call(fn.f, args...) + function overloaded_mapreduce( @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue() ) @@ -558,7 +564,7 @@ function overloaded_mapreduce( # unroll the mapreduce. if typeof(res) == typeof(A) @assert dims == Colon() "dims not supported for mapreduce currently." - return foldl(op, res; init) + return foldl(TracedCall(op), res; init) end return overloaded_mapreduce(identity, op, res; dims=:, init) end From 5fcf4272958c2cc934aa64974851ce6e0e0076f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Sep 2025 17:26:54 -0400 Subject: [PATCH 2/3] feat: split out non-generator changes from #1642 --- src/Reactant.jl | 1 + src/TracedRArray.jl | 7 +++++-- test/basic.jl | 9 +++++++++ test/integration/fillarrays.jl | 2 +- 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index da18116d15..8ae11cf814 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -193,6 +193,7 @@ use_overlayed_version(::TracedRArray) = true use_overlayed_version(::TracedRNumber) = true use_overlayed_version(::Number) = false use_overlayed_version(::MissingTracedValue) = true +use_overlayed_version(::Vector{<:AnyTracedRArray}) = true use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed) function use_overlayed_version(x::AbstractArray) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 1121291d15..3be969fcef 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1509,6 +1509,9 @@ end struct BroadcastIterator{F} f::F + + BroadcastIterator{F}(f::F) where {F} = new{F}(f) + BroadcastIterator(f::F) where {F} = new{F}(f) end (fn::BroadcastIterator)(args...) = fn.f((args...,)) @@ -1536,7 +1539,7 @@ function unrolled_map(f::F, itr) where {F} y === nothing && return [] first, state = y - res_first = Reactant.call_with_reactant(f, first) + res_first = @opcall call(f, first) result = [res_first] while true @@ -1544,7 +1547,7 @@ function unrolled_map(f::F, itr) where {F} y === nothing && break val, state = y - res = Reactant.call_with_reactant(f, val) + res = @opcall call(f, val) push!(result, res) end diff --git a/test/basic.jl b/test/basic.jl index e8e03cab70..d1976f1d02 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1496,6 +1496,15 @@ end @test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y) end + + @testset "mapreduce vector" begin + x = [rand(Float32, 2, 3) for _ in 1:10] + x_ra = Reactant.to_rarray(x) + + @test @jit(mapreduce_vector(x_ra)) ≈ mapreduce_vector(x) + hlo = repr(@code_hlo optimize = false mapreduce_vector(x_ra)) + @test contains(hlo, "call") + end end @testset "Base.Generator" begin diff --git a/test/integration/fillarrays.jl b/test/integration/fillarrays.jl index f7d7359fe6..76ffad570a 100644 --- a/test/integration/fillarrays.jl +++ b/test/integration/fillarrays.jl @@ -25,5 +25,5 @@ end x = OneElement(3.4f0, (3, 4), (32, 32)) rx = Reactant.to_rarray(x) - @test @jit(fn(rx, rx)) ≈ fn(x, x) + @test @jit(fn(rx, rx)) ≈ fn(x, x) atol = 1e-3 rtol = 1e-3 end From 8a5ef1d6072f55da4a5ebb72c2b9d086378776e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 13 Sep 2025 10:13:05 -0400 Subject: [PATCH 3/3] feat: use `traced_call` when unrolling iterators and generators (#1642) * feat: better support for Base.Generators * feat: use traced_call when unrolling iterators and generators * fix: closure with call working * fix: try removing nospecialize * fix: use a looped version of any to avoid inference issues * fix: dont overlay inside compiler call --- src/Reactant.jl | 1 - src/TracedRArray.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 8ae11cf814..da18116d15 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -193,7 +193,6 @@ use_overlayed_version(::TracedRArray) = true use_overlayed_version(::TracedRNumber) = true use_overlayed_version(::Number) = false use_overlayed_version(::MissingTracedValue) = true -use_overlayed_version(::Vector{<:AnyTracedRArray}) = true use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed) function use_overlayed_version(x::AbstractArray) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 3be969fcef..3ff81b6fe3 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1532,7 +1532,6 @@ end unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs) -# TODO: once traced_call supports internal mutations, we can use traced_call here # TODO: we should overload this for Slices and use mapslices instead function unrolled_map(f::F, itr) where {F} y = Reactant.call_with_reactant(iterate, itr)