Skip to content

Commit 2ae712b

Browse files
chore: handle when p is a functor in steady state adjoint
1 parent d3608c4 commit 2ae712b

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/concrete_solve.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,6 +1745,9 @@ function DiffEqBase._concrete_solve_adjoint(
17451745
t, _, _ = canonicalize(Tunable(), p)
17461746
t
17471747
end
1748+
elseif isfunctor(p)
1749+
ps, re = Functors.functor(p)
1750+
ps, x -> (re(x),)
17481751
else
17491752
nothing, x -> (x,)
17501753
end
@@ -1776,13 +1779,21 @@ function DiffEqBase._concrete_solve_adjoint(
17761779
dp, _, _ = canonicalize(Tunable(), dp)
17771780
dp, nothing
17781781
else
1779-
Δp = setproperties(dp, to_nt.prob.p))
1780-
Δtunables, _, _ = canonicalize(Tunable(), Δp)
1781-
dp, _, _ = canonicalize(Tunable(), dp)
1782-
dp, Δtunables
1782+
dp, Δtunables = if isscimlstructure(p)
1783+
Δp = setproperties(dp, to_nt.prob.p))
1784+
Δtunables, _, _ = canonicalize(Tunable(), Δp)
1785+
dp, _, _ = canonicalize(Tunable(), dp)
1786+
dp, Δtunables
1787+
elseif isfunctor(p)
1788+
dp, _ = Functors.functor(dp)
1789+
Δtunables, _ = Functors.functor.prob.p)
1790+
dp, Δtunables
1791+
else
1792+
dp, Δ.prob.p
1793+
end
17831794
end
17841795

1785-
dp = Zygote.accum(dp, Δtunables)
1796+
dp = Zygote.accum(dp, isempty(Δtunables) ? nothing : Δtunables)
17861797

17871798
if originator isa SciMLBase.TrackerOriginator ||
17881799
originator isa SciMLBase.ReverseDiffOriginator

0 commit comments

Comments
 (0)