diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 819c43506c..cb9004c341 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -650,7 +650,7 @@ function Base.mapreduce(
     @nospecialize(op),
     @nospecialize(A::AbstractConcreteArray{T,N});
     dims=:,
-    init=nothing,
+    init=Base._InitialValue(),
 ) where {T,N}
     fn = compile(CallMapReduce(f, op, dims, init), (A,))
     return fn(A)
diff --git a/src/Ops.jl b/src/Ops.jl
index 11ba04d9df..c3a04a19b9 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -46,7 +46,11 @@ macro opcall(expr)
 
     # Generate location info at the callsite
     location_expr = :($(mlir_stacktrace)(
-        joinpath(string(var"#self#"), $(string(func))),
+        if @isdefined(var"#self#")
+            joinpath(string(var"#self#"), $(string(func)))
+        else
+            $(string(func))
+        end,
         $(string(__source__.file)),
         $(__source__.line),
     ))
@@ -2575,7 +2579,7 @@ end
     seen_cache = Reactant.OrderedIdDict()
     Reactant.make_tracer(
         seen_cache,
-        args,
+        fnwrapped ? (f, args) : args,
         (), # we have to insert something here, but we remove it immediately below.
         Reactant.TracedTrack;
         toscalar=false,
diff --git a/src/Overlay.jl b/src/Overlay.jl
index ec92063dd7..951ad3c7c1 100644
--- a/src/Overlay.jl
+++ b/src/Overlay.jl
@@ -156,7 +156,10 @@ end
 end
 
 @reactant_overlay @noinline function Base.mapreduce(
-    f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
+    f,
+    op,
+    A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
+    kwargs...,
 )
     if use_overlayed_version(A)
         return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 87cabf164b..da18116d15 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -60,6 +60,7 @@ function _parent end
 _parent_type(::Type{Array}) = Array
 _parent_type(::Type{Array{T}}) where {T} = Array{T}
 _parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N}
+_parent_type(::Type{<:Slices{P}}) where {P} = P
 
 include("accelerators/Accelerators.jl")
 
@@ -179,10 +180,15 @@ include("TracedRArray.jl")
 include("ConcreteRArray.jl")
 
 use_overlayed_version(x) = false
-use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is)
+function use_overlayed_version(x::F) where {F<:Function}
+    return use_overlayed_version(getfield.(Ref(x), fieldnames(F)))
+end
+use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter))
+use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is)
 use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr)
-use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter)
-use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter))
+use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x)
+use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter)
+use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter))
 use_overlayed_version(::TracedRArray) = true
 use_overlayed_version(::TracedRNumber) = true
 use_overlayed_version(::Number) = false
@@ -195,6 +201,14 @@ function use_overlayed_version(x::AbstractArray)
     return use_overlayed_version(a)
 end
 
+## We avoid calling into `any` to avoid triggering the `any` overlay
+function looped_any(f::F, itr) where {F}
+    @inbounds for x in itr
+        f(x) && return true
+    end
+    return false
+end
+
 # StdLib Overloads
 include("stdlibs/LinearAlgebra.jl")
 include("stdlibs/Random.jl")
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index deaa933d6d..83ce200a4e 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -551,7 +551,7 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
 end
 
 function overloaded_mapreduce(
-    @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
+    @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
 )
     res = unwrapped_broadcast(f, A)
     # This means we are unable to use the optimized dispatches. For now we will
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
     @nospecialize(op),
     @nospecialize(A::AnyTracedRArray{T,N});
     dims=:,
-    init=nothing,
+    init=Base._InitialValue(),
 ) where {T,N}
     A = materialize_traced_array(A)
 
@@ -589,7 +589,7 @@ function overloaded_mapreduce(
 
     res = @opcall reduce(reduce_input, reduce_init, dims, op)
 
-    init !== nothing && (res = op.(res, init))
+    (init isa Base._InitialValue || init === nothing) || (res = op.(res, init))
 
     if original_dims isa Colon
         @assert size(res) == () "expected size of result to be (), got $(size(res))"
@@ -677,6 +677,8 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
     # Special case a union{} return so we can see the better error message
     if ElType === Union{}
         fn(map(first_scalar, bc.args)...)
+    elseif ElType == Any
+        ElType = eltype(fn(map(first_scalar, bc.args)...))
     end
     @assert ElType != Any && ElType != Union{}
     sim = similar(bc, ElType)
@@ -1231,16 +1233,25 @@ function overloaded_map(f, x::AbstractArray, xs::AbstractArray...)
     @assert allequal((axes(x), axes.(xs)...)) "Expected axes of all inputs to map to be \
                                                equal"
 
+    needs_unrolling = falses(length(xs) + 1)
     inputs = ()
-    for input in (x, xs...)
+    for (i, input) in enumerate((x, xs...))
         if input isa AnyTracedRArray
             input = Reactant.materialize_traced_array(input)
-        else
+        elseif eltype(input) <: Reactant.ReactantPrimitive
            input = Reactant.promote_to(TracedRArray{eltype(input),ndims(input)}, input)
+        else
+            needs_unrolling[i] = true
         end
         inputs = (inputs..., input)
     end
 
+    @assert allequal(needs_unrolling) "All inputs to `overloaded_map` must be \
+                                       unrolled or none of them. Open an issue."
+    if needs_unrolling[1]
+        length(inputs) == 1 && return unrolled_map(f, only(inputs))
+        return unrolled_map(splat(f), zip(inputs...))
+    end
     return TracedUtils.elem_apply(f, inputs...)
 end
 
@@ -1321,14 +1332,14 @@ function scan_impl!(
     output::AnyTracedRArray{T,N},
     input::AnyTracedRArray{T,N};
     dims::Integer,
-    init=nothing,
+    init=Base._InitialValue(),
 ) where {T,N}
     @assert dims > 0 "dims must be a positive integer"
     @assert axes(output) == axes(input) "output and input must have the same shape"
 
     dims > ndims(input) && return copyto!(output, input)
 
-    if init === nothing
+    if init isa Base._InitialValue
         op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
         op_in_T === Union{} && (op_in_T = T)
         init = __default_init(T, op)
@@ -1494,27 +1505,44 @@ struct BroadcastIterator{F}
     f::F
 end
 
-(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
+(fn::BroadcastIterator)(args...) = fn.f((args...,))
 
 function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
     min_length = Base.inferencebarrier(minimum)(length, x.is)
     itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
-    if any(Base.Fix2(isa, AnyTracedRArray), itrs)
-        return (BroadcastIterator(f)).(itrs...)
-    else
-        fn = BroadcastIterator(f)
-        return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
-    end
+    any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
+    return broadcast(BroadcastIterator(f), itrs...)
 end
 
 function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
-    if x.itr isa AnyTracedRArray
-        return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
-    else
-        return [f((i, x.itr[i])) for i in 1:length(x.itr)]
-    end
+    x.itr isa AnyTracedRArray || return unrolled_map(f, x)
+    return broadcast(
+        BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
+    )
 end
 
-unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
+unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs)
+
+# TODO: once traced_call supports internal mutations, we can use traced_call here
+# TODO: we should overload this for Slices and use mapslices instead
+function unrolled_map(f::F, itr) where {F}
+    y = Reactant.call_with_reactant(iterate, itr)
+    y === nothing && return []
+
+    first, state = y
+    res_first = Reactant.call_with_reactant(f, first)
+    result = [res_first]
+
+    while true
+        y = Reactant.call_with_reactant(iterate, itr, state)
+        y === nothing && break
+
+        val, state = y
+        res = Reactant.call_with_reactant(f, val)
+        push!(result, res)
+    end
+
+    return result
+end
 
 end
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 554791af8f..98b83ea006 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -660,7 +660,7 @@ function finalize_mlir_fn(
     skipped_results = Reactant.TracedType[]
     for (k, v) in seen_results
         v isa Reactant.TracedType || continue
-        if any(Base.Fix1(===, k), skipped_args)
+        if Reactant.looped_any(Base.Fix1(===, k), skipped_args)
             push!(skipped_results, v)
 
             _, argpath = get_argidx(v, argprefix)
diff --git a/src/utils.jl b/src/utils.jl
index 0081c4be4b..d4af7bb72d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,4 +1,4 @@
-struct CallWithReactant{F}
+struct CallWithReactant{F} <: Function
     f::F
 end
 
diff --git a/test/autodiff.jl b/test/autodiff.jl
index 0f6adf70dc..b52891c161 100644
--- a/test/autodiff.jl
+++ b/test/autodiff.jl
@@ -132,7 +132,7 @@ end
 
 @testset "Forward Gradient" begin
     x = Reactant.to_rarray(3.1 * ones(2, 2))
-    res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x)
+    res = @jit gw(x)
     # TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132
     # to make sure this gets merged as a tracedrarray
     @test res isa Tuple{<:Enzyme.TupleArray{<:ConcreteRNumber{Float64},(2, 2),4,2}}
diff --git a/test/basic.jl b/test/basic.jl
index d174570a89..e8e03cab70 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -926,7 +926,7 @@ end
 
     ra = Reactant.to_rarray(x)
     @jit dip!(ra)
-    ra[:a] ≈ (2.7 * 2) * ones(4)
+    @test ra[:a] ≈ (2.7 * 3.1) * ones(4)
 end
 
 @testset "@code_xla" begin
@@ -1429,7 +1429,10 @@ end
 end
 
 zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
+zip_iterator2(a, b) = mapreduce(splat(.-), +, zip(a, b))
 enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
+enumerate_iterator2(a) = mapreduce(splat(.-), +, enumerate(a))
+mapreduce_vector(a) = mapreduce(-, +, a)
 
 function nested_mapreduce_zip(x, y)
     return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
@@ -1445,37 +1448,49 @@ function nested_mapreduce_hcat(x, y)
     end
 end
 
+function f_generator(points, params)
+    return sum(params * point for point in points)
+end
+
 @testset "Base.Iterators" begin
     @testset "zip" begin
         N = 10
-        a = range(1.0, 5.0; length=N)
-        x = range(10.0, 15.0; length=N + 2)
+        a = collect(range(1.0, 5.0; length=N))
+        x = collect(range(10.0, 15.0; length=N + 2))
         x_ra = Reactant.to_rarray(x)
 
         @test @jit(zip_iterator(a, x_ra)) ≈ zip_iterator(a, x)
+
+        a = [rand(Float32, 2, 3) for _ in 1:10]
+        x = [rand(Float32, 2, 3) for _ in 1:10]
+        a_ra = Reactant.to_rarray(a)
+        x_ra = Reactant.to_rarray(x)
+
+        @test @jit(zip_iterator2(a_ra, x_ra)) ≈ zip_iterator2(a, x)
     end
 
     @testset "enumerate" begin
-        x = range(1.0, 5.0; length=10)
+        x = collect(range(1.0, 5.0; length=10))
         x_ra = Reactant.to_rarray(x)
 
         @test @jit(enumerate_iterator(x_ra)) ≈ enumerate_iterator(x)
+
+        x = [rand(Float32, 2, 3) for _ in 1:10]
+        x_ra = Reactant.to_rarray(x)
+
+        @test @jit(enumerate_iterator2(x_ra)) ≈ enumerate_iterator2(x)
     end
 
     @testset "nested mapreduce" begin
         x = rand(Float32, 4, 3)
         y = rand(Float32, 4, 3)
-
         x_ra = Reactant.to_rarray(x)
         y_ra = Reactant.to_rarray(y)
-
         @test @jit(nested_mapreduce_zip(x_ra, y_ra)) ≈ nested_mapreduce_zip(x, y)
     end
-
     @testset "nested mapreduce hcat" begin
         x = rand(Float32, 4, 3)
         y = rand(Float32, 4, 3)
-
         x_ra = Reactant.to_rarray(x)
         y_ra = Reactant.to_rarray(y)
 
@@ -1483,6 +1498,15 @@ end
     end
 end
 
+@testset "Base.Generator" begin
+    points = eachcol(rand(Float32, 2, 6))
+    params = rand(Float32, 4, 2)
+    points_ra = Reactant.to_rarray(points)
+    params_ra = Reactant.to_rarray(params)
+
+    @test @jit(f_generator(points_ra, params_ra)) ≈ f_generator(points, params)
+end
+
 @testset "compilation cache" begin
     if Reactant.PersistentCompileCache.autotune_cache_enabled() &&
         contains(string(Reactant.devices()[1]), "CUDA")
@@ -1574,3 +1598,31 @@ end
     x_ra = Reactant.to_rarray(x)
     @test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32)
 end
+
+mapped_sub(xs...) = stack(map(-, xs...))
+
+@testset "map of slices" begin
+    # We shouldn't be using `elem_apply` in this case and instead unroll the map;
+    # our passes will fuse them back up if needed.
+    @testset "Vector of Slices" begin
+        x_full = rand(Float32, 10, 5, 3)
+        y_full = rand(Float32, 10, 5, 3)
+        x = [view(x_full, :, i, :) for i in 1:size(x_full, 2)]
+        y = [view(y_full, :, i, :) for i in 1:size(y_full, 2)]
+        x_ra = Reactant.to_rarray(x)
+        y_ra = Reactant.to_rarray(y)
+
+        @test @jit(mapped_sub(x_ra, y_ra)) ≈ mapped_sub(x, y) atol = 1e-5 rtol = 1e-5
+    end
+
+    @testset "Slices" begin
+        x_full = rand(Float32, 10, 5)
+
+        @testset "ColumnSlices" begin
+            x_sliced = eachcol(x_full)
+            x_ra = Reactant.to_rarray(x_sliced)
+
+            @test @jit(mapped_sub(x_ra)) ≈ mapped_sub(x_sliced) atol = 1e-5 rtol = 1e-5
+        end
+    end
+end
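
Illustrative usage (not part of the patch): a minimal sketch of how the generator and slice-mapping support exercised by the new tests above would look from user code, assuming only Reactant's existing public entry points (`Reactant.to_rarray`, `@jit`). The helper functions below simply mirror the test helpers and are otherwise hypothetical.

using Reactant

# mapreduce over a generator expression now routes through the traced overlay
# (see the Base.Generator addition in src/Overlay.jl above).
f_generator(points, params) = sum(params * point for point in points)

points = eachcol(rand(Float32, 2, 6))   # ColumnSlices over a host array
params = rand(Float32, 4, 2)
points_ra = Reactant.to_rarray(points)
params_ra = Reactant.to_rarray(params)
@jit(f_generator(points_ra, params_ra)) # ≈ f_generator(points, params)

# map over column slices is unrolled per slice instead of going through
# `elem_apply`, matching the "map of slices" test above.
mapped_sub(xs...) = stack(map(-, xs...))

x_sliced = eachcol(rand(Float32, 10, 5))
x_ra = Reactant.to_rarray(x_sliced)
@jit(mapped_sub(x_ra))                  # ≈ mapped_sub(x_sliced)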