From 7855f197c49831ef9fffec2620590961baf61116 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 17 Aug 2025 09:56:56 +0530 Subject: [PATCH 1/3] feat: allow `mapreduce` overlay to accept multiple buffers --- src/Overlay.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 43c24ae817..f3def14575 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -156,12 +156,12 @@ end end @reactant_overlay @noinline function Base.mapreduce( - f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs... + f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}...; kwargs... ) if use_overlayed_version(A) - return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) + return TracedRArrayOverrides.overloaded_mapreduce(f, op, A...; kwargs...) else - return Base.inferencebarrier(Base.mapreduce)(f, op, A; kwargs...) + return Base.inferencebarrier(Base.mapreduce)(f, op, A...; kwargs...) end end From 43b3e63fac1e9140f0fc88f644b18b54ac7b361f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 17 Aug 2025 10:30:54 +0530 Subject: [PATCH 2/3] feat: handle multiple buffers in `overloaded_mapreduce` --- src/TracedRArray.jl | 95 ++++++++++++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 31 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index af0a7aac62..ddf0864903 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -587,42 +587,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F} return T(__default_init(Float16, op)) end -function overloaded_mapreduce( - @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing -) - res = unwrapped_broadcast(f, A) - # This means we are unable to use the optimized dispatches. For now we will - # unroll the mapreduce. - if typeof(res) == typeof(A) - @assert dims == Colon() "dims not supported for mapreduce currently." - return foldl(op, res; init) - end - return overloaded_mapreduce(identity, op, res; dims=:, init) -end +_maybe_materialize_traced_array(x::AbstractArray) = materialize_traced_array(x) +_maybe_materialize_traced_array(x) = x + +_change_traced_type(::Type{T}, x::AnyTracedRArray) where {T} = T.(x) +_change_traced_type(::Type{T}, x) where {T} = x function overloaded_mapreduce( @nospecialize(f), @nospecialize(op), - @nospecialize(A::AnyTracedRArray{T,N}); + @nospecialize(A...); dims=:, init=nothing, -) where {T,N} - A = materialize_traced_array(A) +) + if all(x -> !(x isa AnyTracedRArray), A) + res = unwrapped_broadcast(f, A...) + # This means we are unable to use the optimized dispatches. For now we will + # unroll the mapreduce. + if typeof(res) == typeof(A[1]) + @assert dims == Colon() "dims not supported for mapreduce currently." + return foldl(op, res; init) + end + return overloaded_mapreduce(identity, op, res; dims=:, init) + end + + A = _maybe_materialize_traced_array.(A) + mapped_shape = allequal(map(size, A)) ? size(A[1]) : (minimum(length, A),) + N = length(mapped_shape) + A = map(x -> reshape(x, length(x)), A) original_dims = dims dims isa Int && (dims = Int64[dims]) dims isa Colon && (dims = collect(Int64, 1:N)) dims isa Vector{Int64} || (dims = collect(Int64, dims)) - op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T})) + op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Broadcast.eltypes(A))) reduce_init = __default_init(op_in_T, op) if unwrapped_eltype(typeof(reduce_init)) != op_in_T op_in_T = typeof(reduce_init) - A = typeof(reduce_init).(A) + A = _change_traced_type.(typeof(reduce_init), A) end reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init) - reduce_input = materialize_traced_array(broadcast(f, A)) + res = reshape(f.(A...), mapped_shape) + if !(res isa AnyTracedRArray) + @assert dims == Colon() "dims not supported for mapreduce currently." + return foldl(op, res; init) + end + + reduce_input = materialize_traced_array(res) res = @opcall reduce(reduce_input, reduce_init, dims, op) @@ -635,7 +648,7 @@ function overloaded_mapreduce( if res isa TracedRNumber res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ()) end - return @opcall reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N]) + return @opcall reshape(res, [ifelse(i in dims, 1, mapped_shape[i]) for i in 1:N]) end function Base.mapreducedim!( @@ -789,7 +802,6 @@ function _copyto!(dest::Array{<:TracedRNumber}, bc::Broadcasted) bc = Broadcast.preprocess(dest, bc) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) for I in 1:length(dest) dest[I] = Reactant.@allowscalar res[I] @@ -1460,25 +1472,46 @@ end (fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,)) -function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F} +function _canonicalize_iter(x::Base.Iterators.Zip) min_length = Base.inferencebarrier(minimum)(length, x.is) - itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is] - if any(Base.Fix2(isa, AnyTracedRArray), itrs) - return (BroadcastIterator(f)).(itrs...) + iters = last.(_canonicalize_iter.(x.is)) + itrs = [Base.Fix2(getindex, i).(iters) for i in 1:min_length] + any_is_anytraced = any(Base.Fix2(isa, AnyTracedRArray), x.is) + return min_length, any_is_anytraced, itrs +end + +function _canonicalize_iter(x::Base.Iterators.Enumerate) + return _canonicalize_iter(zip(eachindex(x), x)) +end + +_canonicalize_iter(x) = length(x), x isa AnyTracedRArray, x + +function unwrapped_broadcast(f::F, xs...) where {F} + len, any_is_anytraced, itrs = if length(xs) == 1 + _canonicalize_iter(xs[1]) else - fn = BroadcastIterator(f) - return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length] + _canonicalize_iter(zip(xs...)) + end + fn = BroadcastIterator(f) + if any_is_anytraced + return splat(f).(itrs) + else + return [fn(x...) for x in itrs] end end -function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F} - if x.itr isa AnyTracedRArray - return (BroadcastIterator(f)).(1:length(x.itr), x.itr) +function unwrapped_broadcast(f::F, xs::Union{Base.Iterators.Zip, Base.Iterators.Enumerate}) where {F} + len, any_is_anytraced, itrs = _canonicalize_iter(xs) + fn = BroadcastIterator(f) + if any_is_anytraced + return splat(f).(itrs) else - return [f((i, x.itr[i])) for i in 1:length(x.itr)] + return [fn(x...) for x in itrs] end end -unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs] +function unwrapped_broadcast(f::F, xs) where {F} + [f(x) for x in xs] +end end From 8a5189b020ebd36340d97de0b11fa51b8b187b1c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 00:38:02 +0530 Subject: [PATCH 3/3] test: add tests for multi-array `mapreduce` --- test/basic.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 8359b053de..3d63343c9a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1431,6 +1431,8 @@ end end zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b)) +nary_mapreduce(a, b) = mapreduce(*, +, a, b) +nary_mapreduce_dims(a, b) = mapreduce(*, +, a, b; dims = 2) enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a)) function nested_mapreduce_zip(x, y) @@ -1483,6 +1485,20 @@ end @test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y) end + + @testset "n-ary mapreduce" begin + x = rand(Float32, 12) + y = rand(Float32, 12) + z = rand(Float32, 4, 3) + w = rand(Float32, 4, 3) + + rx, ry, rz, rw = Reactant.to_rarray.((x, y, z, w)) + @test @jit(nary_mapreduce(rx, ry)) ≈ nary_mapreduce(x, y) + @test @jit(nary_mapreduce(rx, rz)) ≈ nary_mapreduce(x, z) + @test @jit(nary_mapreduce(rz, rw)) ≈ nary_mapreduce(z, w) + @test @jit(nary_mapreduce_dims(rz, rw)) ≈ nary_mapreduce_dims(z, w) + @test @jit(nary_mapreduce(rz, rx)) ≈ nary_mapreduce(z, x) + end end @testset "compilation cache" begin