Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,22 @@ end

Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()

Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
arrayref(A, i1)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
arrayset(A, convert(T, x)::T, i1)
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref(
A, i1
)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset(
A, convert(T, x)::T, i1
)

# preserve the specific integer type when indexing device arrays,
# to avoid extending 32-bit hardware indices to 64-bit.
Base.to_index(::CuTracedArray, i::Integer) = i

# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
# See also: https://github.com/JuliaLang/julia/pull/42289
Base.@propagate_inbounds Base.getindex(
A::CuTracedArray, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)]
Base.@propagate_inbounds Base.getindex(A::CuTracedArray, I::Union{Integer,CartesianIndex}...) = A[Base._to_linear_index(
A, to_indices(A, I)...
)]
Base.@propagate_inbounds Base.setindex!(
A::CuTracedArray, x, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x
Expand Down
10 changes: 4 additions & 6 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,10 @@ function trace_for(mod, expr)
step = length(range.args) == 3 ? 1 : range.args[3]
limit = range.args[end]

body_symbols = ExpressionExplorer.compute_symbols_state(
quote
$(Expr(:local, assign))
$body
end,
)
body_symbols = ExpressionExplorer.compute_symbols_state(quote
$(Expr(:local, assign))
$body
end)

external_syms = body_symbols.assignments ∪ body_symbols.references
filter!(∉(SPECIAL_SYMBOLS), external_syms)
Expand Down
8 changes: 2 additions & 6 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
$(compiled_symbol) = $(compiler)(
$(f_symbol),
$(args_symbol);
fn_kwargs=$(kwargs_symbol),
fn_kwargs=($(kwargs_symbol)),
$(Expr.(:kw, keys(options), values(options))...),
)
end,
Expand Down Expand Up @@ -1396,11 +1396,7 @@ Generate Julia code to call the XLA executable.
- `nresults`: The number of results to expect.
"""
function codegen_xla_call(
flatten_names,
donated_args_mask,
nresults,
is_sharded::Bool,
ndevices::Int,
flatten_names, donated_args_mask, nresults, is_sharded::Bool, ndevices::Int
)
flatten_buffer_refs = map(n -> :($n.buffer), flatten_names)

Expand Down
6 changes: 4 additions & 2 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ function Base.setindex!(a::ConcretePJRTArray{T}, v, args::Vararg{Int,N}) where {
end

# TODO is there any way to allocate an uninitialized buffer in XLA?
function Base.similar(a::ConcretePJRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
function Base.similar(
a::ConcretePJRTArray{T}, (::Type{S})=T, dims::Dims=size(a)
) where {T,S}
return ConcretePJRTArray(
Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding
)
Expand Down Expand Up @@ -266,7 +268,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
if all(buffer_on_cpu, bc.args) && all(
x ->
!(x isa ConcretePJRTArray) ||
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
bc.args,
)
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
Expand Down
57 changes: 32 additions & 25 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
false,
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#set_reactant_abi,
)
end
else
Expand All @@ -100,11 +100,11 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
false,
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#set_reactant_abi,
)
end
end
Expand All @@ -116,20 +116,25 @@ const enzyme_dupnoneed = 3
const enzyme_outnoneed = 4
const enzyme_constnoneed = 5

@inline act_from_type(x, reverse, needs_primal=true) =
throw(AssertionError("Unhandled activity $(typeof(x))"))
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) =
act_from_type(Enzyme.Const, reverse, needs_primal)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(x, reverse, needs_primal=true) = throw(
AssertionError("Unhandled activity $(typeof(x))")
)
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type(
Enzyme.Const, reverse, needs_primal
)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.BatchDuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) =
act_from_type(Enzyme.Active, reverse, needs_primal)
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_type(
Enzyme.Active, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) =
if needs_primal
enzyme_const
Expand All @@ -151,10 +156,12 @@ const enzyme_constnoneed = 5
end
end

@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) =
act_from_type(Enzyme.DuplicatedNoNeed, Reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = act_from_type(
Enzyme.DuplicatedNoNeed, Reverse, needs_primal
)

@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) =
if needs_primal
Expand Down Expand Up @@ -487,7 +494,7 @@ function overload_autodiff(
false,
TracedUtils.transpose_val(MLIR.IR.result(res, residx));
emptypaths=true,
) #=reverse=#
)#=reverse=#
residx += 1
continue
end
Expand Down
65 changes: 31 additions & 34 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ for (dialect, op) in
res = MLIR.IR.result(
$(:($dialect.$op))(
x.mlir_data;
$(result)=mlir_type(TracedRArray{Bool,N}, size(x)),
($(result))=mlir_type(TracedRArray{Bool,N}, size(x)),
location,
),
)
Expand All @@ -304,7 +304,7 @@ for (dialect, op) in
) where {T}
res = MLIR.IR.result(
$(:($dialect.$op))(
x.mlir_data; $(result)=mlir_type(TracedRArray{Bool,0}, ()), location
x.mlir_data; ($(result))=mlir_type(TracedRArray{Bool,0}, ()), location
),
)
return TracedRNumber{Bool}((), res)
Expand Down Expand Up @@ -1057,15 +1057,14 @@ end
sample_inputs[2i - 1] = Reactant.ConcretePJRTNumber(T(0))
sample_inputs[2i] = Reactant.ConcretePJRTNumber(T(0))
end
func =
Reactant.TracedUtils.make_mlir_fn(
comparator,
(sample_inputs...,),
(),
"comparator";
args_in_result=:none,
return_dialect=:stablehlo,
).f
func = Reactant.TracedUtils.make_mlir_fn(
comparator,
(sample_inputs...,),
(),
"comparator";
args_in_result=:none,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
Expand Down Expand Up @@ -1665,29 +1664,27 @@ end

input_types = [mlir_type(arg) for arg in linear_args]

cond_fn_compiled =
Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:none,
do_transpose=false,
).f

body_fn_compiled =
Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:none,
do_transpose=false,
).f
cond_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:none,
do_transpose=false,
).f

body_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:none,
do_transpose=false,
).f

cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled)
body_reg = Reactant.TracedUtils.__take_region(body_fn_compiled)
Expand Down
2 changes: 1 addition & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ for randfun in (:rand, :randn, :randexp)

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
rng::AbstractRNG, (::Type{T})=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ include("Compiler.jl")
include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false)
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
Expand Down
5 changes: 2 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,8 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
end

indices = [
(
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
).mlir_data for i in indices
(TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data
for i in indices
]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dynamic_update_slice(
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ for (jlop, hloop, hlocomp) in (
function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp))
return Ops.compare(lhs, rhs; comparison_direction=($(hlocomp)))
end

function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
Expand Down
8 changes: 4 additions & 4 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ function materialize_traced_array(
end

get_mlir_data(x::TracedRNumber) = x.mlir_data
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data=data; return x)
get_paths(x::TracedRNumber) = x.paths
set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x)
set_paths!(x::TracedRNumber, paths) = (x.paths=paths; return x)

get_mlir_data(x::TracedRArray) = x.mlir_data
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))
get_paths(x::TracedRArray) = x.paths
set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x)
set_paths!(x::TracedRArray, paths) = (x.paths=paths; return x)

get_paths(x::MissingTracedValue) = x.paths
set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x)
set_paths!(x::MissingTracedValue, paths) = (x.paths=paths; return x)

function set_mlir_data!(x::TracedRArray, data)
x.mlir_data = data
Expand Down
Loading