Skip to content

Commit 66c0730

Browse files
chore: get ext to compile
1 parent a6b7149 commit 66c0730

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function obs_grads(VA, sym, ::Nothing, Δ)
128128
Zygote.nt_nothing(VA)
129129
end
130130

131-
function not_obs_grads(VA::DESolution{T}, sym, not_obss_idx, i, Δ) where {T}
131+
function not_obs_grads(VA::DESolution, sym, not_obss_idx, i, Δ)
132132
Δ′ = map(enumerate(VA.u)) do (t_idx, us)
133133
map(enumerate(us)) do (u_idx, u)
134134
if u_idx in i
@@ -144,7 +144,7 @@ function not_obs_grads(VA::DESolution{T}, sym, not_obss_idx, i, Δ) where {T}
144144
end
145145

146146
@adjoint function Base.getindex(
147-
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where {T}
147+
VA::ODESolution, sym::Union{Tuple, AbstractVector})
148148
function ODESolution_getindex_pullback(Δ)
149149
sym = sym isa Tuple ? collect(sym) : sym
150150
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
@@ -165,8 +165,9 @@ end
165165
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
166166
function NonlinearSolution_getindex_pullback(Δ)
167167
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
168+
@show sym
168169
if is_observed(VA, sym)
169-
f = SII.observed(VA, sym)
170+
f = observed(VA, sym)
170171
p = parameter_values(VA)
171172
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
172173
u = state_values(VA)
@@ -176,7 +177,8 @@ end
176177
f.f_oop(_p, _p)
177178
end
178179
gs = back(Δ)
179-
(gs[1], nothing)
180+
# (gs[1], nothing)
181+
((u = gs[1], prob = (p = gs[2],)), nothing)
180182
elseif i === nothing
181183
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
182184
else
@@ -225,6 +227,7 @@ end
225227
uType2
226228
}
227229
function NonlinearSolutionAdjoint(ȳ)
230+
@show
228231
(ȳ, ntuple(_ -> nothing, length(args))...)
229232
end
230233
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
@@ -233,6 +236,7 @@ end
233236
@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution,
234237
::Val{:u})
235238
function solu_adjoint(Δ)
239+
@show "her"
236240
zerou = zero(sol.prob.u0)
237241
= @. ifelse=== nothing, zerou, Δ)
238242
(build_solution(sol.prob, sol.alg, _Δ, sol.resid),)

0 commit comments

Comments
 (0)