Skip to content

Commit bcf22a8

Browse files
fix: temporarily remove ambiguous adjoint
1 parent 5a5cd39 commit bcf22a8

File tree

1 file changed

+0
-45
lines changed

1 file changed

+0
-45
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -125,32 +125,6 @@ end
125125
VA[sym], ODESolution_getindex_pullback
126126
end
127127

128-
@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon)
129-
function ODESolution_getindex_pullback(Δ)
130-
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
131-
if is_observed(VA, sym)
132-
f = observed(VA, sym)
133-
p = parameter_values(VA)
134-
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
135-
u = state_values(VA)
136-
t = current_time(VA)
137-
y, back = Zygote.pullback(u, tunables) do u, tunables
138-
f.(u, Ref(tunables), t)
139-
end
140-
gs = back(Δ)
141-
(gs[1], nothing)
142-
elseif i === nothing
143-
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
144-
else
145-
@show i
146-
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
147-
for (x, j) in zip(VA.u, 1:length(VA))]
148-
(Δ′, nothing)
149-
end
150-
end
151-
VA[sym, :], ODESolution_getindex_pullback
152-
end
153-
154128
function obs_grads(VA, sym, obs_idx, Δ)
155129
y, back = Zygote.pullback(VA) do sol
156130
getindex.(Ref(sol), sym[obs_idx])
@@ -198,25 +172,6 @@ end
198172
VA[sym], ODESolution_getindex_pullback
199173
end
200174

201-
@adjoint function Base.getindex(
202-
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T}
203-
function ODESolution_getindex_pullback(Δ)
204-
sym = sym isa Tuple ? collect(sym) : sym
205-
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)
206-
207-
obs_idx = findall(s -> is_observed(VA, s), sym)
208-
not_obs_idx = setdiff(1:length(sym), obs_idx)
209-
210-
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
211-
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
212-
213-
a = Zygote.accum(gs_obs[1], gs_not_obs)
214-
215-
(a, nothing)
216-
end
217-
VA[sym], ODESolution_getindex_pullback
218-
end
219-
220175
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14
221176
}(u,
222177
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,

0 commit comments

Comments
 (0)