Skip to content

Commit f020500

Browse files
feat: pass param grad from getindex
1 parent 0329949 commit f020500

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ end
9393
f.(u, Ref(_p), t)
9494
end
9595
gs = back(Δ)
96-
(gs[1], nothing)
96+
((u = gs[1], prob = (p = (tunable = gs[2],),)), nothing)
9797
elseif i === nothing
9898
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
9999
else
@@ -147,7 +147,7 @@ end
147147
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
148148
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
149149

150-
a = Zygote.accum(gs_obs[1], gs_not_obs)
150+
a = Zygote.accum(gs_obs[1], (u = gs_not_obs,))
151151

152152
(a, nothing)
153153
end

0 commit comments

Comments
 (0)