1
1
module SciMLBaseChainRulesCoreExt
2
2
3
3
using SciMLBase
4
+ using SciMLBase: getobserved
4
5
import ChainRulesCore
5
- import ChainRulesCore: NoTangent, @non_differentiable
6
+ import ChainRulesCore: NoTangent, @non_differentiable , zero_tangent, rrule_via_ad
6
7
using SymbolicIndexingInterface
7
8
8
9
function ChainRulesCore. rrule (
@@ -15,52 +16,28 @@ function ChainRulesCore.rrule(
15
16
j:: Integer )
16
17
function ODESolution_getindex_pullback (Δ)
17
18
i = symbolic_type (sym) != NotSymbolic () ? variable_index (VA, sym) : sym
18
- if i === nothing
19
+ du, dprob = if i === nothing
19
20
getter = getobserved (VA)
20
21
grz = rrule_via_ad (config, getter, sym, VA. u[j], VA. prob. p, VA. t[j])[2 ](Δ)
21
- du = [k == j ? grz[2 ] : zero (VA. u[1 ]) for k in 1 : length (VA. u)]
22
- dp = grz[3 ] # pullback for p
22
+ du = [k == j ? grz[3 ] : zero (VA. u[1 ]) for k in 1 : length (VA. u)]
23
+ dp = grz[4 ] # pullback for p
24
+ if dp == NoTangent ()
25
+ dp = zero_tangent (parameter_values (VA. prob))
26
+ end
23
27
dprob = remake (VA. prob, p = dp)
24
28
T = eltype (eltype (VA. u))
25
29
N = length (VA. prob. p)
26
- Δ′ = ODESolution{T, N, typeof (du), Nothing, Nothing, Nothing, Nothing,
27
- typeof (dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing ,
28
- nothing , nothing , nothing , dprob, nothing , nothing ,
29
- VA. dense, 0 , nothing , nothing , VA. retcode)
30
- (NoTangent (), Δ′, NoTangent (), NoTangent ())
30
+ du, dprob
31
31
else
32
32
du = [m == j ? [i == k ? Δ : zero (VA. u[1 ][1 ]) for k in 1 : length (VA. u[1 ])] :
33
33
zero (VA. u[1 ]) for m in 1 : length (VA. u)]
34
- dp = zero (VA. prob. p)
34
+ dp = zero_tangent (VA. prob. p)
35
35
dprob = remake (VA. prob, p = dp)
36
- Δ′ = ODESolution{
37
- T,
38
- N,
39
- typeof (du),
40
- Nothing,
41
- Nothing,
42
- typeof (VA. t),
43
- typeof (VA. k),
44
- typeof (dprob),
45
- typeof (VA. alg),
46
- typeof (VA. interp),
47
- typeof (VA. alg_choice),
48
- typeof (VA. stats)
49
- }(du,
50
- nothing ,
51
- nothing ,
52
- VA. t,
53
- VA. k,
54
- dprob,
55
- VA. alg,
56
- VA. interp,
57
- VA. dense,
58
- 0 ,
59
- VA. stats,
60
- VA. alg_choice,
61
- VA. retcode)
62
- (NoTangent (), Δ′, NoTangent (), NoTangent ())
36
+ du, dprob
63
37
end
38
+ Δ′ = ODESolution {T, N} (du, nothing , nothing , VA. t, VA. k, nothing , dprob,
39
+ VA. alg, VA. interp, VA. dense, 0 , VA. stats, VA. alg_choice, VA. retcode)
40
+ (NoTangent (), Δ′, NoTangent (), NoTangent ())
64
41
end
65
42
VA[sym, j], ODESolution_getindex_pullback
66
43
end
0 commit comments