Skip to content

Commit 136475e

Browse files
sensitivity solution and downstream is passing
1 parent cd48367 commit 136475e

File tree

5 files changed

+74
-18
lines changed

5 files changed

+74
-18
lines changed

src/solutions/ode_solutions.jl

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,42 @@ function solution_new_retcode(sol::AbstractODESolution{T,N},retcode) where {T,N}
100100
sol.alg,sol.interp,sol.dense,sol.tslocation,sol.destats,retcode)
101101
end
102102

103-
function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
104-
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
105-
typeof(sol.t),typeof(sol.k),
106-
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
107-
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
108-
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
103+
function solution_new_tslocation(sol::AbstractODESolution{T,N},tslocation) where {T,N}
104+
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
105+
typeof(sol.t),typeof(sol.k),
106+
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
107+
sol.u,sol.u_analytic,sol.errors,sol.t,sol.k,sol.prob,
108+
sol.alg,sol.interp,sol.dense,tslocation,sol.destats,sol.retcode)
109+
end
110+
111+
function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
112+
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
113+
typeof(sol.t),typeof(sol.k),
114+
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
115+
sol.u[I],
116+
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
117+
sol.errors,sol.t[I],
118+
sol.dense ? sol.k[I] : sol.k,
119+
sol.prob,
120+
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
121+
end
122+
123+
function sensitivity_solution(sol::AbstractODESolution,u,t)
124+
T = eltype(eltype(u))
125+
N = length((size(sol.prob.u0)..., length(u)))
126+
interp = if typeof(sol.interp) <: LinearInterpolation
127+
LinearInterpolation(t,u)
128+
elseif typeof(sol.interp) <: ConstantInterpolation
129+
ConstantInterpolation(t,u)
130+
else
131+
SensitivityInterpolation(t,u)
109132
end
110133

111-
function solution_slice(sol::AbstractODESolution{T,N},I) where {T,N}
112-
ODESolution{T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),
113-
typeof(sol.t),typeof(sol.k),
114-
typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.destats)}(
115-
sol.u[I],
116-
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
117-
sol.errors,sol.t[I],
118-
sol.dense ? sol.k[I] : sol.k,
119-
sol.prob,
120-
sol.alg,sol.interp,false,sol.tslocation,sol.destats,sol.retcode)
121-
end
134+
ODESolution{T,N,typeof(u),typeof(sol.u_analytic),typeof(sol.errors),
135+
typeof(t),Nothing,typeof(sol.prob),typeof(sol.alg),
136+
typeof(interp),typeof(sol.destats)}(
137+
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
138+
sol.alg,interp,
139+
sol.dense,sol.tslocation,
140+
sol.destats,sol.retcode)
141+
end

src/solutions/rode_solutions.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,24 @@ function solution_slice(sol::AbstractRODESolution{T,N},I) where {T,N}
120120
false,sol.tslocation,sol.destats,
121121
sol.retcode,sol.seed)
122122
end
123+
124+
function sensitivity_solution(sol::AbstractRODESolution,u,t)
125+
T = eltype(eltype(u))
126+
N = length((size(sol.prob.u0)..., length(u)))
127+
interp = if typeof(sol.interp) <: LinearInterpolation
128+
LinearInterpolation(t,u)
129+
elseif typeof(sol.interp) <: ConstantInterpolation
130+
ConstantInterpolation(t,u)
131+
else
132+
SensitivityInterpolation(t,u)
133+
end
134+
135+
RODESolution{T,N,typeof(u),typeof(sol.u_analytic),
136+
typeof(sol.errors),typeof(t),
137+
typeof(nothing),typeof(sol.prob),typeof(sol.alg),
138+
typeof(sol.interp),typeof(sol.destats)}(
139+
u,sol.u_analytic,sol.errors,t,nothing,sol.prob,
140+
sol.alg,sol.interp,
141+
sol.dense,sol.tslocation,sol.destats,
142+
sol.retcode,sol.seed)
143+
end

src/solutions/steady_state_solutions.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ function build_solution(prob::AbstractSteadyStateProblem,
1818

1919
SteadyStateSolution{T,N,typeof(u),typeof(resid),typeof(prob),typeof(alg)}(u,resid,prob,alg,retcode)
2020
end
21+
22+
function sensitivity_solution(sol::AbstractSteadyStateSolution,u)
23+
T = eltype(eltype(u))
24+
N = length((size(sol.prob.u0)...,))
25+
26+
SteadyStateSolution{T,N,typeof(u),typeof(sol.resid),
27+
typeof(sol.prob),typeof(sol.alg)}(
28+
u,sol.resid,sol.prob,sol.alg,sol.retcode)
29+
end

src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
6161
end
6262
end
6363

64+
# save_idxs and saveat are here due to https://github.com/FluxML/Zygote.jl/issues/664
6465
function solve(prob::DEProblem,args...;sensealg=nothing,
65-
u0 = nothing, p = nothing,kwargs...)
66+
u0 = nothing, p = nothing, kwargs...)
6667
u0 = u0 !== nothing ? u0 : prob.u0
6768
p = p !== nothing ? p : prob.p
6869
solve_up(prob,sensealg,u0,p,args...;kwargs...)

src/zygote.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,8 @@ end
7373

7474
ZygoteRules.@adjoint numargs(f) = (numargs(f),df->(nothing,))
7575
ChainRulesCore.rrule(::typeof(numargs),f) = (numargs(f),df->(nothing,))
76+
77+
# Until https://github.com/FluxML/Zygote.jl/issues/664 is fixed
78+
ZygoteRules.@adjoint function Base.pairs(x::NamedTuple)
79+
Base.pairs(x), Δ ->.data,)
80+
end

0 commit comments

Comments
 (0)