Skip to content

Commit ac9b145

Browse files
committed
refactor: cache the interpolation manually
As of [email protected] the parameter dependencies are no longer cached by MTKParameters, so we need to cache the interpolation object so that it's not recreated on every function call
1 parent 06b0fd8 commit ac9b145

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

src/Blocks/sources.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -738,9 +738,30 @@ function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any
738738
Symbolics.derivative(args[1], (args[2],), Val(1))
739739
end
740740

741-
@register_symbolic build_interpolation(
742-
interpolation_type::UnionAll, u::AbstractArray, x::AbstractArray, args::Tuple)
743-
build_interpolation(interpolation_type, u, x, args) = interpolation_type(u, x, args...)
741+
function cached_interpolation(interpolation_type, u, x, args)
742+
prev_u = Ref(collect(copy(u)))
743+
prev_x = Ref(collect(copy(x)))
744+
# MTKParameters use views, so we want to ensure that the type is the same
745+
interp = Ref(interpolation_type(prev_u[], prev_x[], args...))
746+
747+
let prev_u = prev_u,
748+
prev_x = prev_x,
749+
interp = interp,
750+
interpolation_type = interpolation_type
751+
function build_interpolation(u, x, args)
752+
if (u, x) (prev_u[], prev_x[])
753+
prev_u[] = collect(u)
754+
prev_x[] = collect(x)
755+
interp[] = interpolation_type(prev_u[], prev_x[], args...)
756+
else
757+
interp[]
758+
end
759+
end
760+
end
761+
end
762+
763+
@register_symbolic interpolation_builder(f::Function, u::AbstractArray, x::AbstractArray, args::Tuple)
764+
interpolation_builder(f, u, x, args) = f(u, x, args)
744765

745766
"""
746767
ParametrizedInterpolation(interp_type, u, x, args...; name)
@@ -762,20 +783,22 @@ such as `LinearInterpolation`, `ConstantInterpolation` or `CubicSpline`.
762783
# Connectors:
763784
- `output`: a [`RealOutput`](@ref) connector corresponding to the interpolated value
764785
"""
765-
function ParametrizedInterpolation(interp_type::T, u, x, args...; name) where {T}
786+
function ParametrizedInterpolation(interp_type::T, u::AbstractVector, x::AbstractVector, args...; name) where {T}
766787
@parameters data[1:length(x)] = u
767788
@parameters ts[1:length(x)] = x
768789
@parameters interpolation_type::T=interp_type [tunable = false] interpolation_args::Tuple=args [tunable = false]
769790
@parameters interpolator::interp_type
770791

792+
build_interpolation = cached_interpolation(interp_type, u, x, args)
793+
@parameters memoized_builder::typeof(build_interpolation)=build_interpolation [tunable = false]
794+
771795
@named output = RealOutput()
772796

773797
eqs = [output.u ~ apply_interpolation(interpolator, t)]
774798

775-
ODESystem(eqs, t, [], [u, x, interpolation_type, interpolator, interpolation_args];
799+
ODESystem(eqs, t, [], [data, ts, interpolation_type, interpolation_args, interpolator, memoized_builder];
776800
parameter_dependencies = [
777-
interpolator => build_interpolation(
778-
interpolation_type, u, x, interpolation_args)
801+
interpolator => interpolation_builder(memoized_builder, data, ts, interpolation_args)
779802
],
780803
systems = [output],
781804
name)

0 commit comments

Comments
 (0)