@@ -43,7 +43,7 @@ Returns true if this function is executed in a Reactant compilation context, oth
43
43
44
44
# Code generation
45
45
"""
46
- @trace <expr>
46
+ @trace [key = val,...] <expr>
47
47
48
48
Converts certain expressions like control flow into a Reactant friendly form. Importantly,
49
49
if no traced value is found inside the expression, then there is no overhead.
@@ -53,7 +53,8 @@ if no traced value is found inside the expression, then there is no overhead.
53
53
- `if` conditions (with `elseif` and other niceties) (`@trace if ...`)
54
54
- `if` statements with a preceeding assignment (`@trace a = if ...`) (note the positioning
55
55
of the macro needs to be before the assignment and not before the `if`)
56
- - `for` statements with a single induction variable iterating over a syntactic `StepRange` of integers.
56
+ - `for` statements with a single induction variable iterating over integers with known `step`
57
+ - `while` statements
57
58
58
59
## Special Considerations
59
60
@@ -129,20 +130,43 @@ function fn(x)
129
130
return y, nothing
130
131
end
131
132
```
133
+
134
+ ### Configuration
135
+
136
+ The behavior of loops can be configured with the following configuration options:
137
+
138
+ - `track_numbers::Union{Bool,Datatype}` - whether Julia numbers should be automatically promoted to traced numbers upon entering the loop.
139
+ - `checkpointing::Bool` - whether or not to enable checkpointing when performing reverse mode differentiation (default: `false`).
140
+ - `mincut::Bool` - whether or not to enable the mincut algorithm when performing reverse mode differentiation (default: `false`).
132
141
"""
133
142
macro trace (args... )
134
143
track_numbers = true
144
+ checkpointing = false
145
+ mincut = false
146
+
135
147
expr = first (args)
136
- if length (args) > 1 && Meta. isexpr (args[1 ], :(= ))
137
- tn_expr = args[1 ]
138
- tn_expr. args[1 ] == :track_numbers ||
139
- error (" @trace supports setting track_numbers, but got $(tn_expr) " )
148
+ while length (args) > 1
149
+ if Meta. isexpr (args[1 ], :(= ))
150
+ tn_expr = args[1 ]
151
+ key, val = tn_expr. args
152
+ key ∈ (:track_numbers , :checkpointing , :mincut ) || error (
153
+ " @trace supports setting track_numbers, checkpointing or mincut, but got $(tn_expr) " ,
154
+ )
140
155
141
- track_numbers = tn_expr. args[2 ]
142
- expr = only (args[2 : end ])
143
- else
144
- expr = only (args)
156
+ if key === :track_numbers
157
+ track_numbers = val
158
+ elseif key === :checkpointing
159
+ checkpointing = val
160
+ elseif key === :mincut
161
+ mincut = val
162
+ end
163
+ args = args[2 : end ]
164
+ else
165
+ break
166
+ end
145
167
end
168
+ expr = only (args)
169
+
146
170
track_numbers = track_numbers ? Number : Union{}
147
171
expr = macroexpand (__module__, expr)
148
172
@@ -159,14 +183,16 @@ macro trace(args...)
159
183
return esc (trace_call (__module__, call))
160
184
end
161
185
Meta. isexpr (expr, :if ) && return esc (trace_if (expr; track_numbers))
162
- Meta. isexpr (expr, :for ) && return (esc (trace_for (expr; track_numbers)))
163
- Meta. isexpr (expr, :while ) && return (esc (trace_while (expr; track_numbers)))
186
+ Meta. isexpr (expr, :for ) &&
187
+ return (esc (trace_for (expr; track_numbers, checkpointing, mincut)))
188
+ Meta. isexpr (expr, :while ) &&
189
+ return (esc (trace_while (expr; track_numbers, checkpointing, mincut)))
164
190
return error (
165
191
" Only `if-elseif-else` blocks, `for` and `while` loops are currently supported by `@trace`" ,
166
192
)
167
193
end
168
194
169
- function trace_while (expr; track_numbers, first_arg= nothing )
195
+ function trace_while (expr; track_numbers, mincut, checkpointing, first_arg= nothing )
170
196
Meta. isexpr (expr, :while , 2 ) || error (" expected while expr" )
171
197
cond, body = expr. args
172
198
@@ -233,6 +259,8 @@ function trace_while(expr; track_numbers, first_arg=nothing)
233
259
$ (args_sym);
234
260
track_numbers= ($ (track_numbers)),
235
261
verify_arg_names= ($ (verify_arg_names_sym)),
262
+ mincut= ($ (mincut)),
263
+ checkpointing= ($ (checkpointing)),
236
264
)
237
265
end
238
266
end
@@ -247,7 +275,7 @@ function trace_while(expr; track_numbers, first_arg=nothing)
247
275
end
248
276
end
249
277
250
- function trace_for (expr; track_numbers)
278
+ function trace_for (expr; track_numbers, checkpointing, mincut )
251
279
Meta. isexpr (expr, :for , 2 ) || error (" expected for expr" )
252
280
assign, body = expr. args
253
281
@@ -325,6 +353,8 @@ function trace_for(expr; track_numbers)
325
353
);
326
354
track_numbers,
327
355
first_arg= counter,
356
+ checkpointing,
357
+ mincut,
328
358
))
329
359
end
330
360
end
0 commit comments