@@ -86,7 +86,7 @@ The `vars` that are being reconstructed must appear in sorted order.
8686 T == AbstractVector{<: Number } ? Vector{Float64} :
8787 T
8888 end ... }
89- F! = JITOpaqueClosure {:reconstruct_derivative, goldclass_sig} () do arg_types...
89+ F! = JITOpaqueClosure {with_eps ? :reconstruct_derivative_with_eps : :reconstruct_derivative, goldclass_sig} () do arg_types...
9090 ir = copy (ir)
9191 ir. argtypes[2 : end ] .= arg_types
9292
@@ -192,9 +192,9 @@ function define_transform_for_reconstruct_der(var_assignment, vars, obs, param_b
192192 @assert with_eps
193193 eps_ii = epsnum (inst[:type ])
194194 input_basis_row = ntuple (neqs + nparams + neps) do active_state_ii
195- Float64 (var_ii == (active_state_ii - neqs + nparams))
195+ Float64 (eps_ii == (active_state_ii - neqs + nparams))
196196 end
197- replace_call! (ir, ssa, Expr (:call , BatchOfBundles{neqs + nparams + neps}, u_ii , input_basis_row... ))
197+ replace_call! (ir, ssa, Expr (:call , BatchOfBundles{neqs + nparams + neps}, 0. , input_basis_row... ))
198198 return nothing
199199 elseif is_solved_variable (stmt) || is_known_invoke (stmt, observed!, ir)
200200 if is_solved_variable (stmt)
@@ -248,7 +248,7 @@ function get_reconstruct_der_visit_custom!(var_assignment)
248248 end
249249
250250 stmt = ir[ssa][:inst ]
251- if is_known_invoke_or_call (stmt, variable, ir) || is_known_invoke_or_call (stmt, state_ddt, ir)
251+ if is_known_invoke_or_call (stmt, variable, ir) || is_known_invoke_or_call (stmt, state_ddt, ir) || is_known_invoke (stmt, epsilon, ir)
252252 return true
253253 elseif is_known_invoke_or_call (stmt, solved_variable, ir)
254254 recurse (stmt. args[end ])
0 commit comments