|
125 | 125 | VA[sym], ODESolution_getindex_pullback |
126 | 126 | end |
127 | 127 |
|
128 | | -@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon) |
129 | | - function ODESolution_getindex_pullback(Δ) |
130 | | - i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym |
131 | | - if is_observed(VA, sym) |
132 | | - f = observed(VA, sym) |
133 | | - p = parameter_values(VA) |
134 | | - tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) |
135 | | - u = state_values(VA) |
136 | | - t = current_time(VA) |
137 | | - y, back = Zygote.pullback(u, tunables) do u, tunables |
138 | | - f.(u, Ref(tunables), t) |
139 | | - end |
140 | | - gs = back(Δ) |
141 | | - (gs[1], nothing) |
142 | | - elseif i === nothing |
143 | | - throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) |
144 | | - else |
145 | | - @show i |
146 | | - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] |
147 | | - for (x, j) in zip(VA.u, 1:length(VA))] |
148 | | - (Δ′, nothing) |
149 | | - end |
150 | | - end |
151 | | - VA[sym, :], ODESolution_getindex_pullback |
152 | | -end |
153 | | - |
154 | 128 | function obs_grads(VA, sym, obs_idx, Δ) |
155 | 129 | y, back = Zygote.pullback(VA) do sol |
156 | 130 | getindex.(Ref(sol), sym[obs_idx]) |
|
198 | 172 | VA[sym], ODESolution_getindex_pullback |
199 | 173 | end |
200 | 174 |
|
201 | | -@adjoint function Base.getindex( |
202 | | - VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T} |
203 | | - function ODESolution_getindex_pullback(Δ) |
204 | | - sym = sym isa Tuple ? collect(sym) : sym |
205 | | - i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) |
206 | | - |
207 | | - obs_idx = findall(s -> is_observed(VA, s), sym) |
208 | | - not_obs_idx = setdiff(1:length(sym), obs_idx) |
209 | | - |
210 | | - gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ) |
211 | | - gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ) |
212 | | - |
213 | | - a = Zygote.accum(gs_obs[1], gs_not_obs) |
214 | | - |
215 | | - (a, nothing) |
216 | | - end |
217 | | - VA[sym], ODESolution_getindex_pullback |
218 | | -end |
219 | | - |
220 | 175 | @adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14 |
221 | 176 | }(u, |
222 | 177 | args...) where {T1, T2, T3, T4, T5, T6, T7, T8, |
|
0 commit comments