@@ -182,7 +182,22 @@ function forward_visit!(ir::IRCode, a::Argument, order::Int, ssa_orders::Vector{
182
182
end
183
183
184
184
185
- function forward_diff_no_inf! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ;
185
+ """
186
+ forward_diff_no_inf!(ir, to_diff)
187
+
188
+ Internal method which generates the code for forward mode diffentiation
189
+
190
+
191
+ - `ir` the IR being differnetation
192
+ - `to_diff`: collection of all SSA values for which the derivative is to be taken,
193
+ paired with the order (first deriviative, second derivative etc)
194
+
195
+ - `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
196
+ decides if the custom `transform!` should be applied to a `stmt` or not
197
+ Default: `false` for all statements
198
+ - `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
199
+ """
200
+ function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue, Int}} ;
186
201
visit_custom! = (args... )-> false , transform! = (args... )-> error ())
187
202
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
188
203
ssa_orders = [0 => false for i = 1 : length (ir. stmts)]
@@ -271,11 +286,11 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
271
286
end
272
287
end
273
288
end
274
-
275
289
end
276
290
291
+
277
292
function forward_diff! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ; kwargs... )
278
- forward_diff_no_inf! (ir, interp, mi, world, to_diff; kwargs... )
293
+ forward_diff_no_inf! (ir, to_diff; kwargs... )
279
294
280
295
# Step 3: Re-inference
281
296
ir = compact! (ir)
0 commit comments