Skip to content

Commit 854e4fd

Browse files
fix: fix ODESolution-related adjoints
1 parent cea0536 commit 854e4fd

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import SciMLStructures
3333
N = length((size(dprob.u0)..., length(du)))
3434
end
3535
Δ′ = ODESolution{T, N}(du, nothing, nothing,
36-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
36+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
3737
VA.alg_choice, VA.retcode)
3838
(Δ′, nothing, nothing)
3939
end
@@ -60,7 +60,7 @@ end
6060
T = eltype(eltype(VA.u))
6161
N = ndims(VA)
6262
Δ′ = ODESolution{T, N}(du, nothing, nothing,
63-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
63+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
6464
VA.alg_choice, VA.retcode)
6565
(Δ′, nothing, nothing)
6666
end
@@ -117,9 +117,11 @@ end
117117
elseif i === nothing
118118
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
119119
else
120-
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
121-
for (x, j) in zip(VA.u, 1:length(VA))]
122-
(Δ′, nothing)
120+
VA = recursivecopy(VA)
121+
recursivefill!(VA, zero(eltype(VA)))
122+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
123+
copyto!(v, Δ)
124+
(VA, nothing)
123125
end
124126
end
125127
VA[sym], ODESolution_getindex_pullback
@@ -172,15 +174,15 @@ end
172174
VA[sym], ODESolution_getindex_pullback
173175
end
174176

175-
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14
176-
}(u,
177+
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
178+
, T14, T15}(u,
177179
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
178-
T9, T10, T11, T12, T13, T14}
180+
T9, T10, T11, T12, T13, T14, T15}
179181
function ODESolutionAdjoint(ȳ)
180182
(ȳ, ntuple(_ -> nothing, length(args))...)
181183
end
182184

183-
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14}(u, args...),
185+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...),
184186
ODESolutionAdjoint
185187
end
186188

0 commit comments

Comments
 (0)