Skip to content

Commit e0e8dcf

Browse files
chore: simplify and generalize obs adjoint
1 parent 14db402 commit e0e8dcf

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,12 @@ end
168168
if is_observed(VA, sym)
169169
f = observed(VA, sym)
170170
p = parameter_values(VA)
171-
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
172171
u = state_values(VA)
173-
t = current_time(VA)
174-
y, back = Zygote.pullback(u, tunables) do u, tunables
175-
_p = repack(tunables)
176-
f.f_oop(u, _p)
172+
_, back = Zygote.pullback(u, p) do u, p
173+
f.f_oop(u, p)
177174
end
178175
gs = back(Δ)
179-
((u = gs[1], prob = (p = (tunable = gs[2],),)), nothing)
176+
((u = gs[1], prob = (p = gs[2],),), nothing)
180177
elseif i === nothing
181178
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
182179
else

0 commit comments

Comments
 (0)