Skip to content

Commit 9afb5c5

Browse files
authored
Fix epsilon derivatives (#5)
Fixes inverter_noise test in Cedar.
1 parent ef6e297 commit 9afb5c5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/transform/state_reconstruct/derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ The `vars` that are being reconstructed must appear in sorted order.
8686
T == AbstractVector{<:Number} ? Vector{Float64} :
8787
T
8888
end...}
89-
F! = JITOpaqueClosure{:reconstruct_derivative, goldclass_sig}() do arg_types...
89+
F! = JITOpaqueClosure{with_eps ? :reconstruct_derivative_with_eps : :reconstruct_derivative, goldclass_sig}() do arg_types...
9090
ir = copy(ir)
9191
ir.argtypes[2:end] .= arg_types
9292

@@ -192,9 +192,9 @@ function define_transform_for_reconstruct_der(var_assignment, vars, obs, param_b
192192
@assert with_eps
193193
eps_ii = epsnum(inst[:type])
194194
input_basis_row = ntuple(neqs + nparams + neps) do active_state_ii
195-
Float64(var_ii == (active_state_ii - neqs + nparams))
195+
Float64(eps_ii == (active_state_ii - neqs + nparams))
196196
end
197-
replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, u_ii, input_basis_row...))
197+
replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, 0., input_basis_row...))
198198
return nothing
199199
elseif is_solved_variable(stmt) || is_known_invoke(stmt, observed!, ir)
200200
if is_solved_variable(stmt)
@@ -248,7 +248,7 @@ function get_reconstruct_der_visit_custom!(var_assignment)
248248
end
249249

250250
stmt = ir[ssa][:inst]
251-
if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir)
251+
if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir) || is_known_invoke(stmt, epsilon, ir)
252252
return true
253253
elseif is_known_invoke_or_call(stmt, solved_variable, ir)
254254
recurse(stmt.args[end])

0 commit comments

Comments
 (0)