@@ -40,13 +40,13 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
4040end
4141function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
4242 val, order:: Int ;
43- custom_diff!, diff_cache)
43+ custom_diff!, diff_cache, eras_mode )
4444 return ChainRulesCore. zero_tangent (val)
4545end
4646function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
4747 arg:: Argument , order:: Int ;
48- custom_diff!, diff_cache)
49- recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
48+ custom_diff!, diff_cache, eras_mode )
49+ recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache, eras_mode )
5050 val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
5151 if val != = nothing
5252 return val
5656
5757function forward_diff_uncached! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
5858 ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ;
59- custom_diff!, diff_cache)
59+ custom_diff!, diff_cache, eras_mode )
6060 stmt = inst[:inst ]
61- recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache)
61+ recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode )
6262 if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
6363 return val
6464 elseif isa (stmt, PiNode)
@@ -212,8 +212,10 @@ Internal method which generates the code for forward mode diffentiation
212212 decides if the custom `transform!` should be applied to a `stmt` or not
213213 Default: `false` for all statements
214214 - `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
215+ - `eras_mode`: determines if to error if not all derivatives are taylor
215216"""
216217function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue,Int}} ;
218+ eras_mode = false ,
217219 visit_custom! = (@nospecialize args... )-> false ,
218220 transform! = (@nospecialize args... )-> error ())
219221 # Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
@@ -286,12 +288,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
286288 newargs = map (stmt. args[2 : end ]) do @nospecialize arg
287289 maparg (arg, SSAValue (ssa), order)
288290 end
289- replace_call! (ir, SSAValue (ssa), Expr (:call , ∂☆ {order} (), newargs... ))
291+ replace_call! (ir, SSAValue (ssa), Expr (:call , ∂☆ {order, eras_mode } (), newargs... ))
290292 elseif isexpr (stmt, :call ) || isexpr (stmt, :new )
291293 newargs = map (stmt. args) do @nospecialize arg
292294 maparg (arg, SSAValue (ssa), order)
293295 end
294- f = isexpr (stmt, :call ) ? ∂☆ {order} () : ∂☆new {order} ()
296+ f = isexpr (stmt, :call ) ? ∂☆ {order, eras_mode } () : ∂☆new {order} ()
295297 replace_call! (ir, SSAValue (ssa), Expr (:call , f, newargs... ))
296298 elseif isa (stmt, PiNode)
297299 # TODO : New PiNode that discriminates based on primal?
0 commit comments