Skip to content

Commit 9bcce26

Browse files
Add SensitivityInterpolation
1 parent a170fd9 commit 9bcce26

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

src/interpolation.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,41 @@ struct ConstantInterpolation{T1,T2} <: AbstractDiffEqInterpolation
2323
u::T2
2424
end
2525

26+
"""
27+
$(TYPEDEF)
28+
"""
29+
struct SensitivityInterpolation{T1,T2} <: AbstractDiffEqInterpolation
30+
t::T1
31+
u::T2
32+
end
33+
2634
interp_summary(::AbstractDiffEqInterpolation) = "Unknown"
2735
interp_summary(::HermiteInterpolation) = "3rd order Hermite"
2836
interp_summary(::LinearInterpolation) = "1st order linear"
2937
interp_summary(::ConstantInterpolation) = "Piecewise constant interpolation"
3038
interp_summary(::Nothing) = "No interpolation"
39+
interp_summary(::SensitivityInterpolation) = "Interpolation disabled due to sensitivity analysis"
3140
interp_summary(sol::DESolution) = interp_summary(sol.interp)
3241

42+
const SENSITIVITY_INTERP_MESSAGE =
43+
"""
44+
Standard interpolation is disabled due to sensitivity analysis being
45+
used for the gradients. Only linear and constant interpolations are
46+
compatible with non-AD sensitivity analysis calculations. Either
47+
utilize tooling like saveat to avoid post-solution interpolation, use
48+
the keyword argument dense=false for linear or constant interpolations,
49+
or use the keyword argument sensealg=SensitivityADPassThrough() to revert
50+
to AD-based derivatives.
51+
"""
52+
3353
(id::HermiteInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3454
(id::HermiteInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3555
(id::LinearInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3656
(id::LinearInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3757
(id::ConstantInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
3858
(id::ConstantInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
59+
(id::SensitivityInterpolation)(tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation(tvals,id,idxs,deriv,p,continuity)
60+
(id::SensitivityInterpolation)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = interpolation!(val,tvals,id,idxs,deriv,p,continuity)
3961

4062
@inline function interpolation(tvals,id,idxs,deriv,p,continuity::Symbol=:left)
4163
t = id.t; u = id.u
@@ -72,6 +94,7 @@ interp_summary(sol::DESolution) = interp_summary(sol.interp)
7294
vals[j] = u[i-1][idxs]
7395
end
7496
else
97+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
7598
dt = t[i] - t[i-1]
7699
Θ = (tval-t[i-1])/dt
77100
idxs_internal = idxs
@@ -119,6 +142,7 @@ times t (sorted), with values u and derivatives ks
119142
vals[j] = u[i-1][idxs]
120143
end
121144
else
145+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
122146
dt = t[i] - t[i-1]
123147
Θ = (tval-t[i-1])/dt
124148
idxs_internal = idxs
@@ -169,6 +193,7 @@ times t (sorted), with values u and derivatives ks
169193
val = u[i-1][idxs]
170194
end
171195
else
196+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
172197
dt = t[i] - t[i-1]
173198
Θ = (tval-t[i-1])/dt
174199
idxs_internal = idxs
@@ -211,6 +236,7 @@ times t (sorted), with values u and derivatives ks
211236
copy!(out,u[i-1][idxs])
212237
end
213238
else
239+
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
214240
dt = t[i] - t[i-1]
215241
Θ = (tval-t[i-1])/dt
216242
idxs_internal = idxs

0 commit comments

Comments
 (0)