Merged
2 changes: 1 addition & 1 deletion src/ConcreteRArray.jl
@@ -650,7 +650,7 @@ function Base.mapreduce(
@nospecialize(op),
@nospecialize(A::AbstractConcreteArray{T,N});
dims=:,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
fn = compile(CallMapReduce(f, op, dims, init), (A,))
return fn(A)
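
Note on the `init` sentinel: `Base._InitialValue()` is the marker Base's own reduction machinery uses for "no `init` supplied", so forwarding it here keeps an explicit `init` distinguishable from the default. A minimal sketch of the two call forms, assuming `@jit`/`to_rarray` as used in the tests:

using Reactant

x = Reactant.to_rarray(rand(Float32, 4))

@jit mapreduce(abs, +, x)                # no init: the sentinel is simply dropped
@jit mapreduce(abs, +, x; init=1.0f0)    # explicit init: folded in via op.(res, init)
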
8 changes: 6 additions & 2 deletions src/Ops.jl
@@ -46,7 +46,11 @@ macro opcall(expr)

# Generate location info at the callsite
location_expr = :($(mlir_stacktrace)(
joinpath(string(var"#self#"), $(string(func))),
if @isdefined(var"#self#")
joinpath(string(var"#self#"), $(string(func)))
else
$(string(func))
end,
$(string(__source__.file)),
$(__source__.line),
))
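
The `@isdefined(var"#self#")` guard covers expansion sites where the implicit `#self#` slot (the enclosing function object inside a method body) does not exist. A rough sketch of the location-name expression the macro now emits, with the op name "add" purely illustrative:

loc_name = if @isdefined(var"#self#")
    joinpath(string(var"#self#"), "add")    # qualified by the enclosing function
else
    "add"                                   # no #self# in scope: just the op name
end
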
@@ -2575,7 +2579,7 @@ end
seen_cache = Reactant.OrderedIdDict()
Reactant.make_tracer(
seen_cache,
args,
fnwrapped ? (f, args) : args,
(), # we have to insert something here, but we remove it immediately below.
Reactant.TracedTrack;
toscalar=false,
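
Tracing `fnwrapped ? (f, args) : args` means a wrapped callable's captured values are walked by the tracer too, not just the positional arguments. A sketch of the case this covers, assuming `@jit`/`to_rarray` as in the tests: the closure below carries a traced array in its fields, which tracing `args` alone would miss.

using Reactant

w = Reactant.to_rarray(rand(Float32, 3))
g = x -> sum(w .* x)                      # `w` lives in g's captured fields

x = Reactant.to_rarray(rand(Float32, 3))
@jit g(x)
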
5 changes: 4 additions & 1 deletion src/Overlay.jl
@@ -156,7 +156,10 @@ end
end

@reactant_overlay @noinline function Base.mapreduce(
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
f,
op,
A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
kwargs...,
)
if use_overlayed_version(A)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
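
Adding `Base.Generator` to the union lets generator expressions reach `overloaded_mapreduce`, since `sum(f(x) for x in xs)` lowers to `mapreduce` over a `Base.Generator`. A minimal sketch in the spirit of the `f_generator` test added in test/basic.jl below:

using Reactant

g(cols, w) = sum(w .* c for c in cols)    # reduces over a Base.Generator

cols = eachcol(rand(Float32, 3, 4))
w = rand(Float32, 3)

@jit g(Reactant.to_rarray(cols), Reactant.to_rarray(w))
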
20 changes: 17 additions & 3 deletions src/Reactant.jl
@@ -60,6 +60,7 @@ function _parent end
_parent_type(::Type{Array}) = Array
_parent_type(::Type{Array{T}}) where {T} = Array{T}
_parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N}
_parent_type(::Type{<:Slices{P}}) where {P} = P

include("accelerators/Accelerators.jl")

@@ -179,10 +180,15 @@ include("TracedRArray.jl")
include("ConcreteRArray.jl")

use_overlayed_version(x) = false
use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is)
function use_overlayed_version(x::F) where {F<:Function}
return use_overlayed_version(getfield.(Ref(x), fieldnames(F)))
end
use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter))
use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is)
use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr)
use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter))
use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x)
use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter))
use_overlayed_version(::TracedRArray) = true
use_overlayed_version(::TracedRNumber) = true
use_overlayed_version(::Number) = false
@@ -195,6 +201,14 @@ function use_overlayed_version(x::AbstractArray)
return use_overlayed_version(a)
end

## We avoid calling into `any` to avoid triggering the `any` overlay
function looped_any(f::F, itr) where {F}
@inbounds for x in itr
f(x) && return true
end
return false
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")
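
The new `Function` method leans on the fact that a closure stores its captured variables as ordinary fields of its type, which is plain Julia behaviour:

w = ones(Float32, 3)
f = x -> w .* x                               # closure capturing `w`

fieldnames(typeof(f))                         # (:w,)
getfield.(Ref(f), fieldnames(typeof(f)))      # (Float32[1.0, 1.0, 1.0],)

`use_overlayed_version` runs that same field walk, and uses `looped_any` instead of `Base.any` so the `any` overlay is never re-entered, to decide whether a traced value is hiding inside a callable, generator, zip, or enumerate wrapper.
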
68 changes: 48 additions & 20 deletions src/TracedRArray.jl
@@ -551,7 +551,7 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
end

function overloaded_mapreduce(
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=Base._InitialValue()
)
res = unwrapped_broadcast(f, A)
# This means we are unable to use the optimized dispatches. For now we will
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
@nospecialize(op),
@nospecialize(A::AnyTracedRArray{T,N});
dims=:,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
A = materialize_traced_array(A)

@@ -589,7 +589,7 @@

res = @opcall reduce(reduce_input, reduce_init, dims, op)

init !== nothing && (res = op.(res, init))
(init isa Base._InitialValue || init === nothing) || (res = op.(res, init))

if original_dims isa Colon
@assert size(res) == () "expected size of result to be (), got $(size(res))"
@@ -677,6 +677,8 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
# Special case a union{} return so we can see the better error message
if ElType === Union{}
fn(map(first_scalar, bc.args)...)
elseif ElType == Any
ElType = eltype(fn(map(first_scalar, bc.args)...))
end
@assert ElType != Any && ElType != Union{}
sim = similar(bc, ElType)
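
The `ElType == Any` branch covers broadcasts whose element type cannot be inferred: rather than tripping the `@assert ElType != Any` just above, the eltype is recovered by calling the function on a representative scalar. A sketch of the kind of function that triggers it, using `Base.inferencebarrier` (which also appears in `unwrapped_broadcast` in this diff):

opaque(x) = Base.inferencebarrier(x) + 1    # inferred return type is Any

Base.promote_op(opaque, Float32)            # Any: roughly what ElType would be
typeof(opaque(1.0f0))                       # Float32: what probing a scalar recovers
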
@@ -1231,16 +1233,25 @@ function overloaded_map(f, x::AbstractArray, xs::AbstractArray...)
@assert allequal((axes(x), axes.(xs)...)) "Expected axes of all inputs to map to be \
equal"

needs_unrolling = falses(length(xs) + 1)
inputs = ()
for input in (x, xs...)
for (i, input) in enumerate((x, xs...))
if input isa AnyTracedRArray
input = Reactant.materialize_traced_array(input)
else
elseif eltype(input) <: Reactant.ReactantPrimitive
input = Reactant.promote_to(TracedRArray{eltype(input),ndims(input)}, input)
else
needs_unrolling[i] = true
end
inputs = (inputs..., input)
end

@assert allequal(needs_unrolling) "All inputs to `overloaded_map` must be \
unrolled or none of them. Open an issue."
if needs_unrolling[1]
length(inputs) == 1 && return unrolled_map(f, only(inputs))
return unrolled_map(splat(f), zip(inputs...))
end
return TracedUtils.elem_apply(f, inputs...)
end
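
The rewritten `overloaded_map` now chooses between two paths: inputs that are (or promote to) traced arrays of `ReactantPrimitive` eltype go through `elem_apply`, while anything else (for example a `Vector` whose elements are themselves arrays) is marked for unrolling and handled one element at a time by `unrolled_map`. A small sketch, assuming `@jit`/`to_rarray` as in the tests:

using Reactant

a = Reactant.to_rarray(rand(Float32, 4))
b = Reactant.to_rarray(rand(Float32, 4))
@jit map(+, a, b)                                               # elem_apply path

xs = [rand(Float32, 2) for _ in 1:3]
ys = [rand(Float32, 2) for _ in 1:3]
sub_each(x, y) = map(-, x, y)
@jit sub_each(Reactant.to_rarray(xs), Reactant.to_rarray(ys))   # unrolled path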

@@ -1321,14 +1332,14 @@ function scan_impl!(
output::AnyTracedRArray{T,N},
input::AnyTracedRArray{T,N};
dims::Integer,
init=nothing,
init=Base._InitialValue(),
) where {T,N}
@assert dims > 0 "dims must be a positive integer"
@assert axes(output) == axes(input) "output and input must have the same shape"

dims > ndims(input) && return copyto!(output, input)

if init === nothing
if init isa Base._InitialValue
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
op_in_T === Union{} && (op_in_T = T)
init = __default_init(T, op)
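
With the sentinel, the default-init branch only fires when the caller passed no `init` at all; `__default_init(T, op)` then supplies a neutral starting value for the op (a sketch of the intent, assuming identity-element semantics analogous to Base's reductions):

accumulate(+, Float32[1, 2, 3])    # Float32[1.0, 3.0, 6.0], as if init were zero(Float32)
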
@@ -1494,27 +1505,44 @@ struct BroadcastIterator{F}
f::F
end

(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
(fn::BroadcastIterator)(args...) = fn.f((args...,))

function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
min_length = Base.inferencebarrier(minimum)(length, x.is)
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
return (BroadcastIterator(f)).(itrs...)
else
fn = BroadcastIterator(f)
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
end
any(Base.Fix2(isa, AnyTracedRArray), itrs) || return unrolled_map(f, x)
return broadcast(BroadcastIterator(f), itrs...)
end

function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
if x.itr isa AnyTracedRArray
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
else
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
end
x.itr isa AnyTracedRArray || return unrolled_map(f, x)
return broadcast(
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
)
end

unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
unwrapped_broadcast(f::F, xs) where {F} = unrolled_map(f, xs)

# TODO: once traced_call supports internal mutations, we can use traced_call here
# TODO: we should overload this for Slices and use mapslices instead
function unrolled_map(f::F, itr) where {F}
y = Reactant.call_with_reactant(iterate, itr)
y === nothing && return []

first, state = y
res_first = Reactant.call_with_reactant(f, first)
result = [res_first]

while true
y = Reactant.call_with_reactant(iterate, itr, state)
y === nothing && break

val, state = y
res = Reactant.call_with_reactant(f, val)
push!(result, res)
end

return result
end

end
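
The simplified `BroadcastIterator` exists because `zip` and `enumerate` hand the mapped function one tuple per element, while `broadcast` supplies the arguments separately; the wrapper just re-tuples them before calling `f`. A standalone sketch of the behaviour (renamed so it does not clash with the real type):

struct TupleCaller{F}
    f::F
end
(fn::TupleCaller)(args...) = fn.f((args...,))

TupleCaller(splat(*))(2.0, 3.0)    # same as splat(*)((2.0, 3.0)) == 6.0
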
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
@@ -660,7 +660,7 @@ function finalize_mlir_fn(
skipped_results = Reactant.TracedType[]
for (k, v) in seen_results
v isa Reactant.TracedType || continue
if any(Base.Fix1(===, k), skipped_args)
if Reactant.looped_any(Base.Fix1(===, k), skipped_args)
push!(skipped_results, v)

_, argpath = get_argidx(v, argprefix)
2 changes: 1 addition & 1 deletion src/utils.jl
@@ -1,4 +1,4 @@
struct CallWithReactant{F}
struct CallWithReactant{F} <: Function
f::F
end

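
Subtyping `Function` matters for dispatch that is constrained to callables; one plausible consumer is the new `use_overlayed_version(x::F) where {F<:Function}` method, which only applies to subtypes of `Function`. A small self-contained illustration of the difference:

struct Wrapper{F}
    f::F
end
struct FunWrapper{F} <: Function
    f::F
end

describe(::Function) = "dispatches as a Function"
describe(::Any) = "falls through to Any"

describe(Wrapper(sin))       # "falls through to Any"
describe(FunWrapper(sin))    # "dispatches as a Function"
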
2 changes: 1 addition & 1 deletion test/autodiff.jl
@@ -132,7 +132,7 @@ end

@testset "Forward Gradient" begin
x = Reactant.to_rarray(3.1 * ones(2, 2))
res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x)
res = @jit gw(x)
# TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132
# to make sure this gets merged as a tracedrarray
@test res isa Tuple{<:Enzyme.TupleArray{<:ConcreteRNumber{Float64},(2, 2),4,2}}
68 changes: 60 additions & 8 deletions test/basic.jl
@@ -926,7 +926,7 @@ end

ra = Reactant.to_rarray(x)
@jit dip!(ra)
ra[:a] ≈ (2.7 * 2) * ones(4)
@test ra[:a] ≈ (2.7 * 3.1) * ones(4)
end

@testset "@code_xla" begin
@@ -1429,7 +1429,10 @@ end
end

zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
zip_iterator2(a, b) = mapreduce(splat(.-), +, zip(a, b))
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
enumerate_iterator2(a) = mapreduce(splat(.-), +, enumerate(a))
mapreduce_vector(a) = mapreduce(-, +, a)

function nested_mapreduce_zip(x, y)
return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
@@ -1445,44 +1448,65 @@ function nested_mapreduce_hcat(x, y)
end
end

function f_generator(points, params)
return sum(params * point for point in points)
end

@testset "Base.Iterators" begin
@testset "zip" begin
N = 10
a = range(1.0, 5.0; length=N)
x = range(10.0, 15.0; length=N + 2)
a = collect(range(1.0, 5.0; length=N))
x = collect(range(10.0, 15.0; length=N + 2))
x_ra = Reactant.to_rarray(x)

@test @jit(zip_iterator(a, x_ra)) ≈ zip_iterator(a, x)

a = [rand(Float32, 2, 3) for _ in 1:10]
x = [rand(Float32, 2, 3) for _ in 1:10]
a_ra = Reactant.to_rarray(a)
x_ra = Reactant.to_rarray(x)

@test @jit(zip_iterator2(a_ra, x_ra)) ≈ zip_iterator2(a, x)
end

@testset "enumerate" begin
x = range(1.0, 5.0; length=10)
x = collect(range(1.0, 5.0; length=10))
x_ra = Reactant.to_rarray(x)

@test @jit(enumerate_iterator(x_ra)) ≈ enumerate_iterator(x)

x = [rand(Float32, 2, 3) for _ in 1:10]
x_ra = Reactant.to_rarray(x)

@test @jit(enumerate_iterator2(x_ra)) ≈ enumerate_iterator2(x)
end

@testset "nested mapreduce" begin
x = rand(Float32, 4, 3)
y = rand(Float32, 4, 3)

x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(nested_mapreduce_zip(x_ra, y_ra)) ≈ nested_mapreduce_zip(x, y)
end

@testset "nested mapreduce hcat" begin
x = rand(Float32, 4, 3)
y = rand(Float32, 4, 3)

x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y)
end
end

@testset "Base.Generator" begin
points = eachcol(rand(Float32, 2, 6))
params = rand(Float32, 4, 2)
points_ra = Reactant.to_rarray(points)
params_ra = Reactant.to_rarray(params)

@test @jit(f_generator(points_ra, params_ra)) ≈ f_generator(points, params)
end

@testset "compilation cache" begin
if Reactant.PersistentCompileCache.autotune_cache_enabled() &&
contains(string(Reactant.devices()[1]), "CUDA")
@@ -1574,3 +1598,31 @@ end
x_ra = Reactant.to_rarray(x)
@test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32)
end

mapped_sub(xs...) = stack(map(-, xs...))

@testset "map of slices" begin
# We shouldn't be using `elem_apply` in this case and instead unroll the map
# our passes will fuse them backup if needed.
@testset "Vector of Slices" begin
x_full = rand(Float32, 10, 5, 3)
y_full = rand(Float32, 10, 5, 3)
x = [view(x_full, :, i, :) for i in 1:size(x_full, 2)]
y = [view(y_full, :, i, :) for i in 1:size(y_full, 2)]
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(mapped_sub(x_ra, y_ra)) ≈ mapped_sub(x, y) atol = 1e-5 rtol = 1e-5
end

@testset "Slices" begin
x_full = rand(Float32, 10, 5)

@testset "ColumnSlices" begin
x_sliced = eachcol(x_full)
x_ra = Reactant.to_rarray(x_sliced)

@test @jit(mapped_sub(x_ra)) ≈ mapped_sub(x_sliced) atol = 1e-5 rtol = 1e-5
end
end
end