Skip to content

Commit aa06f6d

Browse files
fix: replace Zygote adjoint with ChainRulesCore adjoint
1 parent a1aec24 commit aa06f6d

File tree

2 files changed

+14
-67
lines changed

2 files changed

+14
-67
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module SciMLBaseChainRulesCoreExt
22

33
using SciMLBase
4+
using SciMLBase: getobserved
45
import ChainRulesCore
5-
import ChainRulesCore: NoTangent, @non_differentiable
6+
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
67
using SymbolicIndexingInterface
78

89
function ChainRulesCore.rrule(
@@ -15,52 +16,28 @@ function ChainRulesCore.rrule(
1516
j::Integer)
1617
function ODESolution_getindex_pullback(Δ)
1718
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
18-
if i === nothing
19+
du, dprob = if i === nothing
1920
getter = getobserved(VA)
2021
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
21-
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
22-
dp = grz[3] # pullback for p
22+
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)]
23+
dp = grz[4] # pullback for p
24+
if dp == NoTangent()
25+
dp = zero_tangent(parameter_values(VA.prob))
26+
end
2327
dprob = remake(VA.prob, p = dp)
2428
T = eltype(eltype(VA.u))
2529
N = length(VA.prob.p)
26-
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
27-
typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing,
28-
nothing, nothing, nothing, dprob, nothing, nothing,
29-
VA.dense, 0, nothing, nothing, VA.retcode)
30-
(NoTangent(), Δ′, NoTangent(), NoTangent())
30+
du, dprob
3131
else
3232
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
3333
zero(VA.u[1]) for m in 1:length(VA.u)]
34-
dp = zero(VA.prob.p)
34+
dp = zero_tangent(VA.prob.p)
3535
dprob = remake(VA.prob, p = dp)
36-
Δ′ = ODESolution{
37-
T,
38-
N,
39-
typeof(du),
40-
Nothing,
41-
Nothing,
42-
typeof(VA.t),
43-
typeof(VA.k),
44-
typeof(dprob),
45-
typeof(VA.alg),
46-
typeof(VA.interp),
47-
typeof(VA.alg_choice),
48-
typeof(VA.stats)
49-
}(du,
50-
nothing,
51-
nothing,
52-
VA.t,
53-
VA.k,
54-
dprob,
55-
VA.alg,
56-
VA.interp,
57-
VA.dense,
58-
0,
59-
VA.stats,
60-
VA.alg_choice,
61-
VA.retcode)
62-
(NoTangent(), Δ′, NoTangent(), NoTangent())
36+
du, dprob
6337
end
38+
Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob,
39+
VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode)
40+
(NoTangent(), Δ′, NoTangent(), NoTangent())
6441
end
6542
VA[sym, j], ODESolution_getindex_pullback
6643
end

ext/SciMLBaseZygoteExt.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,36 +40,6 @@ import SciMLStructures
4040
VA[i, j], ODESolution_getindex_pullback
4141
end
4242

43-
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
44-
function ODESolution_getindex_pullback(Δ)
45-
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
46-
du, dprob = if i === nothing
47-
getter = getobserved(VA)
48-
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
49-
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
50-
dp = grz[3] # pullback for p
51-
if dp === nothing
52-
dp = parameter_values(VA)
53-
end
54-
dprob = remake(VA.prob, p = dp)
55-
du, dprob
56-
else
57-
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
58-
zero(VA.u[1]) for m in 1:length(VA.u)]
59-
dp = zero(VA.prob.p)
60-
dprob = remake(VA.prob, p = dp)
61-
du, dprob
62-
end
63-
T = eltype(eltype(VA.u))
64-
N = ndims(VA)
65-
Δ′ = ODESolution{T, N}(du, nothing, nothing,
66-
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
67-
VA.alg_choice, VA.retcode)
68-
(Δ′, nothing, nothing)
69-
end
70-
VA[sym, j], ODESolution_getindex_pullback
71-
end
72-
7343
@adjoint function EnsembleSolution(sim, time, converged, stats)
7444
out = EnsembleSolution(sim, time, converged, stats)
7545
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}

0 commit comments

Comments
 (0)