Skip to content

Commit 02a9ca1

Browse files
Format code
1 parent 4e23a04 commit 02a9ca1

32 files changed

+278
-215
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,20 +160,22 @@ end
160160

161161
Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()
162162

163-
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
164-
arrayref(A, i1)
165-
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
166-
arrayset(A, convert(T, x)::T, i1)
163+
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref(
164+
A, i1
165+
)
166+
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset(
167+
A, convert(T, x)::T, i1
168+
)
167169

168170
# preserve the specific integer type when indexing device arrays,
169171
# to avoid extending 32-bit hardware indices to 64-bit.
170172
Base.to_index(::CuTracedArray, i::Integer) = i
171173

172174
# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
173175
# See also: https://github.com/JuliaLang/julia/pull/42289
174-
Base.@propagate_inbounds Base.getindex(
175-
A::CuTracedArray, I::Union{Integer,CartesianIndex}...
176-
) = A[Base._to_linear_index(A, to_indices(A, I)...)]
176+
Base.@propagate_inbounds Base.getindex(A::CuTracedArray, I::Union{Integer,CartesianIndex}...) = A[Base._to_linear_index(
177+
A, to_indices(A, I)...
178+
)]
177179
Base.@propagate_inbounds Base.setindex!(
178180
A::CuTracedArray, x, I::Union{Integer,CartesianIndex}...
179181
) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,10 @@ function trace_for(mod, expr)
157157
step = length(range.args) == 3 ? 1 : range.args[3]
158158
limit = range.args[end]
159159

160-
body_symbols = ExpressionExplorer.compute_symbols_state(
161-
quote
162-
$(Expr(:local, assign))
163-
$body
164-
end,
165-
)
160+
body_symbols = ExpressionExplorer.compute_symbols_state(quote
161+
$(Expr(:local, assign))
162+
$body
163+
end)
166164

167165
external_syms = body_symbols.assignments body_symbols.references
168166
filter!((SPECIAL_SYMBOLS), external_syms)

src/Compiler.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
10991099
$(compiled_symbol) = $(compiler)(
11001100
$(f_symbol),
11011101
$(args_symbol);
1102-
fn_kwargs=$(kwargs_symbol),
1102+
fn_kwargs=($(kwargs_symbol)),
11031103
$(Expr.(:kw, keys(options), values(options))...),
11041104
)
11051105
end,
@@ -1396,11 +1396,7 @@ Generate Julia code to call the XLA executable.
13961396
- `nresults`: The number of results to expect.
13971397
"""
13981398
function codegen_xla_call(
1399-
flatten_names,
1400-
donated_args_mask,
1401-
nresults,
1402-
is_sharded::Bool,
1403-
ndevices::Int,
1399+
flatten_names, donated_args_mask, nresults, is_sharded::Bool, ndevices::Int
14041400
)
14051401
flatten_buffer_refs = map(n -> :($n.buffer), flatten_names)
14061402

src/ConcreteRArray.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ function Base.setindex!(a::ConcretePJRTArray{T}, v, args::Vararg{Int,N}) where {
238238
end
239239

240240
# TODO is there any way to allocate an uninitialized buffer in XLA?
241-
function Base.similar(a::ConcretePJRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
241+
function Base.similar(
242+
a::ConcretePJRTArray{T}, (::Type{S})=T, dims::Dims=size(a)
243+
) where {T,S}
242244
return ConcretePJRTArray(
243245
Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding
244246
)
@@ -266,7 +268,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
266268
if all(buffer_on_cpu, bc.args) && all(
267269
x ->
268270
!(x isa ConcretePJRTArray) ||
269-
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
271+
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
270272
bc.args,
271273
)
272274
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)

src/Interpreter.jl

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ end
8383
ReactantCacheToken(),
8484
REACTANT_METHOD_TABLE,
8585
world,
86-
false, #=forward_rules=#
87-
false, #=reverse_rules=#
88-
false, #=inactive_rules=#
89-
false, #=broadcast_rewrite=#
90-
set_reactant_abi,
86+
false,
87+
#=forward_rules=#false,
88+
#=reverse_rules=#false,
89+
#=inactive_rules=#false,
90+
#=broadcast_rewrite=#set_reactant_abi,
9191
)
9292
end
9393
else
@@ -100,11 +100,11 @@ else
100100
REACTANT_CACHE,
101101
REACTANT_METHOD_TABLE,
102102
world,
103-
false, #=forward_rules=#
104-
false, #=reverse_rules=#
105-
false, #=inactive_rules=#
106-
false, #=broadcast_rewrite=#
107-
set_reactant_abi,
103+
false,
104+
#=forward_rules=#false,
105+
#=reverse_rules=#false,
106+
#=inactive_rules=#false,
107+
#=broadcast_rewrite=#set_reactant_abi,
108108
)
109109
end
110110
end
@@ -116,20 +116,25 @@ const enzyme_dupnoneed = 3
116116
const enzyme_outnoneed = 4
117117
const enzyme_constnoneed = 5
118118

119-
@inline act_from_type(x, reverse, needs_primal=true) =
120-
throw(AssertionError("Unhandled activity $(typeof(x))"))
121-
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) =
122-
act_from_type(Enzyme.Const, reverse, needs_primal)
123-
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) =
124-
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
119+
@inline act_from_type(x, reverse, needs_primal=true) = throw(
120+
AssertionError("Unhandled activity $(typeof(x))")
121+
)
122+
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type(
123+
Enzyme.Const, reverse, needs_primal
124+
)
125+
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type(
126+
Enzyme.Duplicated, reverse, needs_primal
127+
)
125128
@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) =
126129
reverse ? enzyme_out : enzyme_dupnoneed
127-
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) =
128-
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
130+
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = act_from_type(
131+
Enzyme.Duplicated, reverse, needs_primal
132+
)
129133
@inline act_from_type(::Enzyme.BatchDuplicatedNoNeed, reverse, needs_primal=true) =
130134
reverse ? enzyme_out : enzyme_dupnoneed
131-
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) =
132-
act_from_type(Enzyme.Active, reverse, needs_primal)
135+
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_type(
136+
Enzyme.Active, reverse, needs_primal
137+
)
133138
@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) =
134139
if needs_primal
135140
enzyme_const
@@ -151,10 +156,12 @@ const enzyme_constnoneed = 5
151156
end
152157
end
153158

154-
@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) =
155-
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
156-
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) =
157-
act_from_type(Enzyme.DuplicatedNoNeed, Reverse, needs_primal)
159+
@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = act_from_type(
160+
Enzyme.Duplicated, reverse, needs_primal
161+
)
162+
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = act_from_type(
163+
Enzyme.DuplicatedNoNeed, Reverse, needs_primal
164+
)
158165

159166
@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) =
160167
if needs_primal
@@ -487,7 +494,7 @@ function overload_autodiff(
487494
false,
488495
TracedUtils.transpose_val(MLIR.IR.result(res, residx));
489496
emptypaths=true,
490-
) #=reverse=#
497+
)#=reverse=#
491498
residx += 1
492499
continue
493500
end

src/Ops.jl

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ for (dialect, op) in
291291
res = MLIR.IR.result(
292292
$(:($dialect.$op))(
293293
x.mlir_data;
294-
$(result)=mlir_type(TracedRArray{Bool,N}, size(x)),
294+
($(result))=mlir_type(TracedRArray{Bool,N}, size(x)),
295295
location,
296296
),
297297
)
@@ -304,7 +304,7 @@ for (dialect, op) in
304304
) where {T}
305305
res = MLIR.IR.result(
306306
$(:($dialect.$op))(
307-
x.mlir_data; $(result)=mlir_type(TracedRArray{Bool,0}, ()), location
307+
x.mlir_data; ($(result))=mlir_type(TracedRArray{Bool,0}, ()), location
308308
),
309309
)
310310
return TracedRNumber{Bool}((), res)
@@ -1057,15 +1057,14 @@ end
10571057
sample_inputs[2i - 1] = Reactant.ConcretePJRTNumber(T(0))
10581058
sample_inputs[2i] = Reactant.ConcretePJRTNumber(T(0))
10591059
end
1060-
func =
1061-
Reactant.TracedUtils.make_mlir_fn(
1062-
comparator,
1063-
(sample_inputs...,),
1064-
(),
1065-
"comparator";
1066-
args_in_result=:none,
1067-
return_dialect=:stablehlo,
1068-
).f
1060+
func = Reactant.TracedUtils.make_mlir_fn(
1061+
comparator,
1062+
(sample_inputs...,),
1063+
(),
1064+
"comparator";
1065+
args_in_result=:none,
1066+
return_dialect=:stablehlo,
1067+
).f
10691068
@assert MLIR.IR.nregions(func) == 1
10701069
fn_name = String(
10711070
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
@@ -1665,29 +1664,27 @@ end
16651664

16661665
input_types = [mlir_type(arg) for arg in linear_args]
16671666

1668-
cond_fn_compiled =
1669-
Reactant.TracedUtils.make_mlir_fn(
1670-
cond_fn,
1671-
traced_args,
1672-
(),
1673-
string(gensym("cond_fn")),
1674-
false;
1675-
return_dialect=:stablehlo,
1676-
args_in_result=:none,
1677-
do_transpose=false,
1678-
).f
1679-
1680-
body_fn_compiled =
1681-
Reactant.TracedUtils.make_mlir_fn(
1682-
body_fn,
1683-
traced_args,
1684-
(),
1685-
string(gensym("body_fn")),
1686-
false;
1687-
return_dialect=:stablehlo,
1688-
args_in_result=:none,
1689-
do_transpose=false,
1690-
).f
1667+
cond_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
1668+
cond_fn,
1669+
traced_args,
1670+
(),
1671+
string(gensym("cond_fn")),
1672+
false;
1673+
return_dialect=:stablehlo,
1674+
args_in_result=:none,
1675+
do_transpose=false,
1676+
).f
1677+
1678+
body_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
1679+
body_fn,
1680+
traced_args,
1681+
(),
1682+
string(gensym("body_fn")),
1683+
false;
1684+
return_dialect=:stablehlo,
1685+
args_in_result=:none,
1686+
do_transpose=false,
1687+
).f
16911688

16921689
cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled)
16931690
body_reg = Reactant.TracedUtils.__take_region(body_fn_compiled)

src/Overlay.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ for randfun in (:rand, :randn, :randexp)
8080

8181
# scalars
8282
@reactant_overlay @noinline function Random.$(randfun)(
83-
rng::AbstractRNG, ::Type{T}=Float64
83+
rng::AbstractRNG, (::Type{T})=Float64
8484
) where {T}
8585
if T <: ReactantPrimitive
8686
return TracedRandom.$(overload_randfun)(rng, T)

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ include("Compiler.jl")
136136
include("Overlay.jl")
137137

138138
function Enzyme.make_zero(
139-
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
139+
::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false)
140140
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
141141
if haskey(seen, prev)
142142
return seen[prev]

src/TracedRArray.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,8 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
389389
end
390390

391391
indices = [
392-
(
393-
TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1
394-
).mlir_data for i in indices
392+
(TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data
393+
for i in indices
395394
]
396395
res = MLIR.IR.result(
397396
MLIR.Dialects.stablehlo.dynamic_update_slice(

src/TracedRNumber.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ for (jlop, hloop, hlocomp) in (
163163
function $(jlop)(
164164
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
165165
) where {T}
166-
return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp))
166+
return Ops.compare(lhs, rhs; comparison_direction=($(hlocomp)))
167167
end
168168

169169
function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}

0 commit comments

Comments
 (0)