Skip to content

Commit 0f0ded5

Browse files
authored
@trace: Add way to configure checkpointing and mincut (#1390)
* `@trace`: Add way to configure checkpointing and mincut * fmt
1 parent a645d94 commit 0f0ded5

File tree

5 files changed

+85
-29
lines changed

5 files changed

+85
-29
lines changed

docs/src/tutorials/control-flow.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ In addition to conditional evaluations, [`@trace`](@ref) also supports capturing
126126
loops. This is possible in the form of both for and while loops.
127127
This enables one to write algorithm that would not be possible otherwise such as
128128
performing computations until convergence or running a computation for an certain
129-
number of iterations which is only known during runtime.
129+
number of iterations which is only known during runtime.
130130

131131
Here is an example of a function which computes the cumsum in non-optimized manner
132132
using a for loop:

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Returns true if this function is executed in a Reactant compilation context, oth
4343

4444
# Code generation
4545
"""
46-
@trace <expr>
46+
@trace [key = val,...] <expr>
4747
4848
Converts certain expressions like control flow into a Reactant friendly form. Importantly,
4949
if no traced value is found inside the expression, then there is no overhead.
@@ -53,7 +53,8 @@ if no traced value is found inside the expression, then there is no overhead.
5353
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
5454
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
5555
of the macro needs to be before the assignment and not before the `if`)
56-
- `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers.
56+
- `for` statements with a single induction variable iterating over integers with known `step`
57+
- `while` statements
5758
5859
## Special Considerations
5960
@@ -129,20 +130,43 @@ function fn(x)
129130
return y, nothing
130131
end
131132
```
133+
134+
### Configuration
135+
136+
The behavior of loops can be configured with the following configuration options:
137+
138+
- `track_numbers::Union{Bool,Datatype}` - whether Julia numbers should be automatically promoted to traced numbers upon entering the loop.
139+
- `checkpointing::Bool` - whether or not to enable checkpointing when performing reverse mode differentiation (default: `false`).
140+
- `mincut::Bool` - whether or not to enable the mincut algorithm when performing reverse mode differentiation (default: `false`).
132141
"""
133142
macro trace(args...)
134143
track_numbers = true
144+
checkpointing = false
145+
mincut = false
146+
135147
expr = first(args)
136-
if length(args) > 1 && Meta.isexpr(args[1], :(=))
137-
tn_expr = args[1]
138-
tn_expr.args[1] == :track_numbers ||
139-
error("@trace supports setting track_numbers, but got $(tn_expr)")
148+
while length(args) > 1
149+
if Meta.isexpr(args[1], :(=))
150+
tn_expr = args[1]
151+
key, val = tn_expr.args
152+
key (:track_numbers, :checkpointing, :mincut) || error(
153+
"@trace supports setting track_numbers, checkpointing or mincut, but got $(tn_expr)",
154+
)
140155

141-
track_numbers = tn_expr.args[2]
142-
expr = only(args[2:end])
143-
else
144-
expr = only(args)
156+
if key === :track_numbers
157+
track_numbers = val
158+
elseif key === :checkpointing
159+
checkpointing = val
160+
elseif key === :mincut
161+
mincut = val
162+
end
163+
args = args[2:end]
164+
else
165+
break
166+
end
145167
end
168+
expr = only(args)
169+
146170
track_numbers = track_numbers ? Number : Union{}
147171
expr = macroexpand(__module__, expr)
148172

@@ -159,14 +183,16 @@ macro trace(args...)
159183
return esc(trace_call(__module__, call))
160184
end
161185
Meta.isexpr(expr, :if) && return esc(trace_if(expr; track_numbers))
162-
Meta.isexpr(expr, :for) && return (esc(trace_for(expr; track_numbers)))
163-
Meta.isexpr(expr, :while) && return (esc(trace_while(expr; track_numbers)))
186+
Meta.isexpr(expr, :for) &&
187+
return (esc(trace_for(expr; track_numbers, checkpointing, mincut)))
188+
Meta.isexpr(expr, :while) &&
189+
return (esc(trace_while(expr; track_numbers, checkpointing, mincut)))
164190
return error(
165191
"Only `if-elseif-else` blocks, `for` and `while` loops are currently supported by `@trace`",
166192
)
167193
end
168194

169-
function trace_while(expr; track_numbers, first_arg=nothing)
195+
function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothing)
170196
Meta.isexpr(expr, :while, 2) || error("expected while expr")
171197
cond, body = expr.args
172198

@@ -233,6 +259,8 @@ function trace_while(expr; track_numbers, first_arg=nothing)
233259
$(args_sym);
234260
track_numbers=($(track_numbers)),
235261
verify_arg_names=($(verify_arg_names_sym)),
262+
mincut=($(mincut)),
263+
checkpointing=($(checkpointing)),
236264
)
237265
end
238266
end
@@ -247,7 +275,7 @@ function trace_while(expr; track_numbers, first_arg=nothing)
247275
end
248276
end
249277

250-
function trace_for(expr; track_numbers)
278+
function trace_for(expr; track_numbers, checkpointing, mincut)
251279
Meta.isexpr(expr, :for, 2) || error("expected for expr")
252280
assign, body = expr.args
253281

@@ -325,6 +353,8 @@ function trace_for(expr; track_numbers)
325353
);
326354
track_numbers,
327355
first_arg=counter,
356+
checkpointing,
357+
mincut,
328358
))
329359
end
330360
end

src/ControlFlow.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@ function ReactantCore.traced_call(f::Function, args...)
99
end
1010

1111
function ReactantCore.traced_while(
12-
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing
12+
cond_fn::CFn,
13+
body_fn::BFn,
14+
args;
15+
track_numbers=Number,
16+
verify_arg_names=nothing,
17+
checkpointing=false,
18+
mincut=false,
1319
) where {CFn,BFn}
14-
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names)
20+
return Ops.while_loop(
21+
cond_fn, body_fn, args...; track_numbers, verify_arg_names, checkpointing, mincut
22+
)
1523
end

src/Ops.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,7 +1867,13 @@ end
18671867
end
18681868

18691869
@noinline function while_loop(
1870-
cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing
1870+
cond_fn::CFn,
1871+
body_fn::BFn,
1872+
args...;
1873+
track_numbers,
1874+
verify_arg_names=nothing,
1875+
checkpointing=false,
1876+
mincut=false,
18711877
) where {CFn,BFn}
18721878
# TODO: detect and prevent mutation within the condition
18731879

@@ -1933,7 +1939,15 @@ end
19331939
cond=cond_reg,
19341940
body=body_reg,
19351941
)
1936-
MLIR.IR.attr!(while_op, "enzymexla.disable_min_cut", MLIR.IR.UnitAttribute())
1942+
1943+
if !mincut
1944+
MLIR.IR.attr!(while_op, "enzymexla.disable_min_cut", MLIR.IR.UnitAttribute())
1945+
end
1946+
1947+
if checkpointing
1948+
MLIR.IR.attr!(while_op, "enzymexla.enable_checkpointing", MLIR.IR.Attribute(true))
1949+
end
1950+
19371951
return map(enumerate(linear_args)) do (i, arg)
19381952
Reactant.TracedUtils.set_mlir_data!(arg, MLIR.IR.result(while_op, i))
19391953
end

test/control_flow.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,17 @@ function while_convergence(x, y)
676676
return diff
677677
end
678678

679+
@testset "while: convergence" begin
680+
x = [1.0, 10.0, 20.0]
681+
y = [0.0, -2.0, -3.0]
682+
x_ra = Reactant.to_rarray(x)
683+
y_ra = Reactant.to_rarray(y)
684+
685+
@test @jit(while_convergence(x_ra, y_ra)) while_convergence(x, y)
686+
end
687+
679688
function for_no_track_numbers(x, n)
680-
@trace track_numbers = false for i in n:16
689+
@trace mincut = false checkpointing = true track_numbers = false for i in n:16
681690
x = x .+ 1
682691
end
683692
return x
@@ -694,16 +703,11 @@ end
694703
for_no_track_numbers_ra = @compile optimize = "enzyme-batch" for_no_track_numbers(
695704
x_ra, n_ra
696705
)
697-
for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)
698-
end
699-
700-
@testset "while: convergence" begin
701-
x = [1.0, 10.0, 20.0]
702-
y = [0.0, -2.0, -3.0]
703-
x_ra = Reactant.to_rarray(x)
704-
y_ra = Reactant.to_rarray(y)
706+
@test for_no_track_numbers_ra(x_ra, n_ra) == for_no_track_numbers(x, n)
705707

706-
@test @jit(while_convergence(x_ra, y_ra)) while_convergence(x, y)
708+
ir = sprint(show, @code_hlo optimize = "enzyme-batch" for_no_track_numbers(x_ra, n_ra))
709+
@test contains(ir, "enzymexla.disable_min_cut")
710+
@test contains(ir, "enzymexla.enable_checkpointing")
707711
end
708712

709713
_call1(a, b) = a

0 commit comments

Comments
 (0)