Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ function Base.mapreduce(
@nospecialize(op),
@nospecialize(A::AbstractConcreteArray{T,N});
dims=:,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
fn = compile(CallMapReduce(f, op, dims, init), (A,))
return fn(A)
Expand Down
8 changes: 6 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ macro opcall(expr)

# Generate location info at the callsite
location_expr = :($(mlir_stacktrace)(
joinpath(string(var"#self#"), $(string(func))),
if @isdefined(var"#self#")
joinpath(string(var"#self#"), $(string(func)))
else
$(string(func))
end,
$(string(__source__.file)),
$(__source__.line),
))
Expand Down Expand Up @@ -2575,7 +2579,7 @@ end
seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
fnwrapped ? (f, args) : args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
Expand Down
5 changes: 4 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ end
end

@reactant_overlay @noinline function Base.mapreduce(
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
f,
op,
A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
kwargs...,
)
if use_overlayed_version(A)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
Expand Down
20 changes: 17 additions & 3 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function _parent end
# For plain `Array` types, `_parent_type` hands back the type it was given,
# at every level of parameterization (`Array`, `Array{T}`, `Array{T,N}`).
function _parent_type(::Type{Array})
    return Array
end

function _parent_type(::Type{Array{T}}) where {T}
    return Array{T}
end

function _parent_type(::Type{Array{T,N}}) where {T,N}
    return Array{T,N}
end

# For any `Slices` type, return its first type parameter `P`.
function _parent_type(::Type{<:Slices{P}}) where {P}
    return P
end

include("accelerators/Accelerators.jl")

Expand Down Expand Up @@ -179,10 +180,15 @@ include("TracedRArray.jl")
include("ConcreteRArray.jl")

use_overlayed_version(x) = false
use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is)
# Callable objects (anything `<:Function`): read every field of `x` into a
# tuple and defer the decision to the method that handles that tuple.
function use_overlayed_version(x::F) where {F<:Function}
    names = fieldnames(F)
    # Broadcasting over the tuple of field names yields a tuple of field values.
    field_values = getfield.(Ref(x), names)
    return use_overlayed_version(field_values)
end
use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter))
use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is)
use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr)
use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter))
use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x)
use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter))
use_overlayed_version(::TracedRArray) = true
use_overlayed_version(::TracedRNumber) = true
use_overlayed_version(::Number) = false
Expand All @@ -195,6 +201,14 @@ function use_overlayed_version(x::AbstractArray)
return use_overlayed_version(a)
end

## Hand-written replacement for `any(f, itr)`: we avoid calling into `any` so
## that the `any` overlay is not triggered.
## Returns `true` as soon as `f` yields `true` for an element of `itr` (later
## elements are not visited), and `false` when no element does, including when
## `itr` is empty.
function looped_any(f::F, itr) where {F}
    found = false
    @inbounds for elem in itr
        if f(elem)
            found = true
            break
        end
    end
    return found
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")
Expand Down
52 changes: 40 additions & 12 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
end

function overloaded_mapreduce(
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
)
res = unwrapped_broadcast(f, A)
# This means we are unable to use the optimized dispatches. For now we will
Expand All @@ -568,7 +568,7 @@ function overloaded_mapreduce(
@nospecialize(op),
@nospecialize(A::AnyTracedRArray{T,N});
dims=:,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
A = materialize_traced_array(A)

Expand All @@ -589,7 +589,7 @@ function overloaded_mapreduce(

res = @opcall reduce(reduce_input, reduce_init, dims, op)

init !== nothing && (res = op.(res, init))
(init isa Base._InitialValue || init === nothing) || (res = op.(res, init))

if original_dims isa Colon
@assert size(res) == () "expected size of result to be (), got $(size(res))"
Expand Down Expand Up @@ -677,6 +677,8 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
# Special case a union{} return so we can see the better error message
if ElType === Union{}
fn(map(first_scalar, bc.args)...)
elseif ElType == Any
ElType = eltype(fn(map(first_scalar, bc.args)...))
end
@assert ElType != Any && ElType != Union{}
sim = similar(bc, ElType)
Expand Down Expand Up @@ -1321,14 +1323,14 @@ function scan_impl!(
output::AnyTracedRArray{T,N},
input::AnyTracedRArray{T,N};
dims::Integer,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
@assert dims > 0 "dims must be a positive integer"
@assert axes(output) == axes(input) "output and input must have the same shape"

dims > ndims(input) && return copyto!(output, input)

if init === nothing
if init isa Base._InitialValue
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
op_in_T === Union{} && (op_in_T = T)
init = __default_init(T, op)
Expand Down Expand Up @@ -1494,27 +1496,53 @@ struct BroadcastIterator{F}
f::F
end

(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
(fn::BroadcastIterator)(args...) = fn.f((args...,))

function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
min_length = Base.inferencebarrier(minimum)(length, x.is)
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
return (BroadcastIterator(f)).(itrs...)
return broadcast(BroadcastIterator(f), itrs...)
else
fn = BroadcastIterator(f)
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
return unwrapped_broadcast_with_iterate(f, x)
end
end

function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
if x.itr isa AnyTracedRArray
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
return broadcast(
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
)
else
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
return unwrapped_broadcast_with_iterate(f, x)
end
end

unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
# Generators are mapped element by element through the `iterate`-based helper.
unwrapped_broadcast(f::F, x::Base.Generator) where {F} =
    unwrapped_broadcast_with_iterate(f, x)

# Generic fallback: any other input is handled by the same `iterate`-based helper.
function unwrapped_broadcast(f::F, xs) where {F}
    return unwrapped_broadcast_with_iterate(f, xs)
end

# TODO: once traced_call supports internal mutations, we can use traced_call here
#
# Apply `f` to every element obtained by iterating `itr` and collect the
# results in a `Vector`. Both the `iterate` calls and the calls to `f` are
# routed through `Reactant.call_with_reactant`.
#
# Returns an untyped empty vector (`[]`) when `itr` yields nothing; otherwise
# the vector is created from the first result and grown with `push!`.
function unwrapped_broadcast_with_iterate(f::F, itr) where {F}
    next = Reactant.call_with_reactant(iterate, itr)
    next === nothing && return []

    item, state = next
    # The first mapped value seeds the output vector (and hence its eltype).
    outputs = [Reactant.call_with_reactant(f, item)]

    next = Reactant.call_with_reactant(iterate, itr, state)
    while next !== nothing
        item, state = next
        push!(outputs, Reactant.call_with_reactant(f, item))
        next = Reactant.call_with_reactant(iterate, itr, state)
    end

    return outputs
end

end
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ function finalize_mlir_fn(
skipped_results = Reactant.TracedType[]
for (k, v) in seen_results
v isa Reactant.TracedType || continue
if any(Base.Fix1(===, k), skipped_args)
if Reactant.looped_any(Base.Fix1(===, k), skipped_args)
push!(skipped_results, v)

_, argpath = get_argidx(v, argprefix)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct CallWithReactant{F}
struct CallWithReactant{F} <: Function
f::F
end

Expand Down
2 changes: 1 addition & 1 deletion test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ end

@testset "Forward Gradient" begin
x = Reactant.to_rarray(3.1 * ones(2, 2))
res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x)
res = @jit gw(x)
# TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132
# to make sure this gets merged as a tracedrarray
@test res isa Tuple{<:Enzyme.TupleArray{<:ConcreteRNumber{Float64},(2, 2),4,2}}
Expand Down
40 changes: 32 additions & 8 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ end

ra = Reactant.to_rarray(x)
@jit dip!(ra)
ra[:a] ≈ (2.7 * 2) * ones(4)
@test ra[:a] ≈ (2.7 * 3.1) * ones(4)
end

@testset "@code_xla" begin
Expand Down Expand Up @@ -1429,7 +1429,10 @@ end
end

# Sum of element-wise products over the pairs produced by `zip(a, b)`.
function zip_iterator(a, b)
    return mapreduce(splat(*), +, zip(a, b))
end

# Sum of broadcasted differences over the pairs produced by `zip(a, b)`.
function zip_iterator2(a, b)
    return mapreduce(splat(.-), +, zip(a, b))
end

# Sum of `index * value` over `enumerate(a)`.
function enumerate_iterator(a)
    return mapreduce(splat(*), +, enumerate(a))
end

# Sum of broadcasted `index .- value` over `enumerate(a)`.
function enumerate_iterator2(a)
    return mapreduce(splat(.-), +, enumerate(a))
end

# Sum of the negated elements of `a`.
function mapreduce_vector(a)
    return mapreduce(-, +, a)
end

function nested_mapreduce_zip(x, y)
return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
Expand All @@ -1445,44 +1448,65 @@ function nested_mapreduce_hcat(x, y)
end
end

# Sum `params * point` over every `point` in `points`. The products are
# produced lazily by a generator expression, which is what gets passed to `sum`.
function f_generator(points, params)
    products = (params * point for point in points)
    return sum(products)
end

@testset "Base.Iterators" begin
@testset "zip" begin
N = 10
a = range(1.0, 5.0; length=N)
x = range(10.0, 15.0; length=N + 2)
a = collect(range(1.0, 5.0; length=N))
x = collect(range(10.0, 15.0; length=N + 2))
x_ra = Reactant.to_rarray(x)

@test @jit(zip_iterator(a, x_ra)) ≈ zip_iterator(a, x)

a = [rand(Float32, 2, 3) for _ in 1:10]
x = [rand(Float32, 2, 3) for _ in 1:10]
a_ra = Reactant.to_rarray(a)
x_ra = Reactant.to_rarray(x)

@test @jit(zip_iterator2(a_ra, x_ra)) ≈ zip_iterator2(a, x)
end

@testset "enumerate" begin
x = range(1.0, 5.0; length=10)
x = collect(range(1.0, 5.0; length=10))
x_ra = Reactant.to_rarray(x)

@test @jit(enumerate_iterator(x_ra)) ≈ enumerate_iterator(x)

x = [rand(Float32, 2, 3) for _ in 1:10]
x_ra = Reactant.to_rarray(x)

@test @jit(enumerate_iterator2(x_ra)) ≈ enumerate_iterator2(x)
end

@testset "nested mapreduce" begin
x = rand(Float32, 4, 3)
y = rand(Float32, 4, 3)

x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(nested_mapreduce_zip(x_ra, y_ra)) ≈ nested_mapreduce_zip(x, y)
end

@testset "nested mapreduce hcat" begin
x = rand(Float32, 4, 3)
y = rand(Float32, 4, 3)

x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y)
end
end

@testset "Base.Generator" begin
    # Host-side inputs: the columns of a 2×6 matrix and a 4×2 matrix.
    points = eachcol(rand(Float32, 2, 6))
    params = rand(Float32, 4, 2)

    # Counterparts of the inputs produced by `Reactant.to_rarray`.
    points_ra = Reactant.to_rarray(points)
    params_ra = Reactant.to_rarray(params)

    # The `@jit` result must agree with the plain Julia evaluation.
    expected = f_generator(points, params)
    compiled_result = @jit f_generator(points_ra, params_ra)
    @test compiled_result ≈ expected
end

@testset "compilation cache" begin
if Reactant.PersistentCompileCache.autotune_cache_enabled() &&
contains(string(Reactant.devices()[1]), "CUDA")
Expand Down
Loading