@@ -738,9 +738,30 @@ function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any
738738 Symbolics. derivative (args[1 ], (args[2 ],), Val (1 ))
739739end
740740
741- @register_symbolic build_interpolation (
742- interpolation_type:: UnionAll , u:: AbstractArray , x:: AbstractArray , args:: Tuple )
743- build_interpolation (interpolation_type, u, x, args) = interpolation_type (u, x, args... )
741+ function cached_interpolation (interpolation_type, u, x, args)
742+ prev_u = Ref (collect (copy (u)))
743+ prev_x = Ref (collect (copy (x)))
744+ # MTKParameters use views, so we want to ensure that the type is the same
745+ interp = Ref (interpolation_type (prev_u[], prev_x[], args... ))
746+
747+ let prev_u = prev_u,
748+ prev_x = prev_x,
749+ interp = interp,
750+ interpolation_type = interpolation_type
751+ function build_interpolation (u, x, args)
752+ if (u, x) ≠ (prev_u[], prev_x[])
753+ prev_u[] = collect (u)
754+ prev_x[] = collect (x)
755+ interp[] = interpolation_type (prev_u[], prev_x[], args... )
756+ else
757+ interp[]
758+ end
759+ end
760+ end
761+ end
762+
763+ @register_symbolic interpolation_builder (f:: Function , u:: AbstractArray , x:: AbstractArray , args:: Tuple )
764+ interpolation_builder (f, u, x, args) = f (u, x, args)
744765
745766"""
746767 ParametrizedInterpolation(interp_type, u, x, args...; name)
@@ -762,20 +783,22 @@ such as `LinearInterpolation`, `ConstantInterpolation` or `CubicSpline`.
762783# Connectors:
763784 - `output`: a [`RealOutput`](@ref) connector corresponding to the interpolated value
764785"""
765- function ParametrizedInterpolation (interp_type:: T , u, x, args... ; name) where {T}
786+ function ParametrizedInterpolation (interp_type:: T , u:: AbstractVector , x:: AbstractVector , args... ; name) where {T}
766787 @parameters data[1 : length (x)] = u
767788 @parameters ts[1 : length (x)] = x
768789 @parameters interpolation_type:: T = interp_type [tunable = false ] interpolation_args:: Tuple = args [tunable = false ]
769790 @parameters interpolator:: interp_type
770791
792+ build_interpolation = cached_interpolation (interp_type, u, x, args)
793+ @parameters memoized_builder:: typeof (build_interpolation)= build_interpolation [tunable = false ]
794+
771795 @named output = RealOutput ()
772796
773797 eqs = [output. u ~ apply_interpolation (interpolator, t)]
774798
775- ODESystem (eqs, t, [], [u, x , interpolation_type, interpolator, interpolation_args ];
799+ ODESystem (eqs, t, [], [data, ts , interpolation_type, interpolation_args, interpolator, memoized_builder ];
776800 parameter_dependencies = [
777- interpolator => build_interpolation (
778- interpolation_type, u, x, interpolation_args)
801+ interpolator => interpolation_builder (memoized_builder, data, ts, interpolation_args)
779802 ],
780803 systems = [output],
781804 name)
0 commit comments