@@ -733,43 +733,86 @@ function SampledData(; name, buffer, sample_time, circular_buffer)
733733end 
734734
735735#  This needs to be extend for interpolation types
736- apply_interpolation (interp, t) =  interp (t)
736+ #  apply_interpolation(interp, t) = interp(t)
737+ 
738+ #  function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any}, ::Val{2})
739+ #      Symbolics.derivative(args[1], (args[2],), Val(1))
740+ #  end
741+ 
742+ #  function cached_interpolation(interpolation_type, u, x, args)
743+ #      prev_u = DiffCache(u)
744+ #      # Interpolation points can be a range, but we want to be able
745+ #      # to update the cache if needed (and setindex! is not defined on ranges)
746+ #      # with a view from MTKParameters, so we collect to get a vector
747+ #      prev_x = DiffCache(collect(x))
748+ #      interp = GeneralLazyBufferCache() do (u, x)
749+ #          interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
750+ #      end
751+ 
752+ #      let prev_u = prev_u,
753+ #          prev_x = prev_x,
754+ #          interp = interp,
755+ #          interpolation_type = interpolation_type
756+ 
757+ #          function build_interpolation(u, x, args)
758+ #              if (u, x) ≠ (get_tmp(prev_u, u), get_tmp(prev_x, x))
759+ #                  get_tmp(prev_u, u) .= u
760+ #                  get_tmp(prev_x, x) .= x
761+ #                  interp.bufs[(u, x)] = interpolation_type(
762+ #                      get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
763+ #              else
764+ #                  interp[(u, x)]
765+ #              end
766+ #          end
767+ #      end
768+ #  end
769+ 
770+ struct  CachedInterpolation{T,I,U,X,C}
771+     interpolation_type:: I 
772+     prev_u:: U 
773+     prev_x:: X 
774+     cache:: C 
775+ 
776+     function  CachedInterpolation (interpolation_type, u, x, args)
777+         #  we need to copy the inputs to avoid aliasing
778+         prev_u =  DiffCache (copy (u))
779+         #  Interpolation points can be a range, but we want to be able
780+         #  to update the cache if needed (and setindex! is not defined on ranges)
781+         #  with a view from MTKParameters, so we collect to get a vector
782+         prev_x =  DiffCache (collect (copy (x)))
783+         cache =  GeneralLazyBufferCache () do  (u, x)
784+             interpolation_type (get_tmp (prev_u, u), get_tmp (prev_x, x), args... )
785+         end 
786+         T =  typeof (cache[(get_tmp (prev_u, u), get_tmp (prev_x, x))])
787+         I =  typeof (interpolation_type)
788+         U =  typeof (prev_u)
789+         X =  typeof (prev_x)
790+         C =  typeof (cache)
737791
738- function  Symbolics . derivative ( :: typeof (apply_interpolation), args :: NTuple{2, Any} ,  :: Val{2} )
739-     Symbolics . derivative (args[ 1 ], (args[ 2 ],),  Val ( 1 )) 
792+          new {T,I,U,X,C} (interpolation_type, prev_u, prev_x, cache )
793+     end 
740794end 
741795
742- function  cached_interpolation (interpolation_type, u, x, args)
743-     prev_u =  DiffCache (u)
744-     #  Interpolation points can be a range, but we want to be able
745-     #  to update the cache if needed (and setindex! is not defined on ranges)
746-     #  with a view from MTKParameters, so we collect to get a vector
747-     prev_x =  DiffCache (collect (x))
748-     interp =  GeneralLazyBufferCache () do  (u, x)
749-         interpolation_type (get_tmp (prev_u, u), get_tmp (prev_x, x), args... )
750-     end 
796+ function  (f:: CachedInterpolation{T} )(u, x, args) where  T
797+     (;prev_u, prev_x, cache, interpolation_type) =  f
751798
752-     let  prev_u =  prev_u,
753-         prev_x =  prev_x,
754-         interp =  interp,
755-         interpolation_type =  interpolation_type
756- 
757-         function  build_interpolation (u, x, args)
758-             if  (u, x) ≠  (get_tmp (prev_u, u), get_tmp (prev_x, x))
759-                 get_tmp (prev_u, u) .=  u
760-                 get_tmp (prev_x, x) .=  x
761-                 interp. bufs[(u, x)] =  interpolation_type (
762-                     get_tmp (prev_u, u), get_tmp (prev_x, x), args... )
763-             else 
764-                 interp[(u, x)]
765-             end 
766-         end 
799+     interp =  @inbounds  if  (u, x) ≠  (get_tmp (prev_u, u), get_tmp (prev_x, x))
800+         get_tmp (prev_u, u) .=  u
801+         get_tmp (prev_x, x) .=  x
802+         #  @info "cache miss"
803+         cache. bufs[(u, x)] =  interpolation_type (
804+             get_tmp (prev_u, u), get_tmp (prev_x, x), args... )
805+     else 
806+         #  @info "cache hit"
807+         cache[(u, x)]
767808    end 
809+ 
810+     return  interp
768811end 
769812
770- @register_symbolic   interpolation_builder ( 
771-     f :: Function , u :: AbstractArray , x :: AbstractArray , args :: Tuple ) 
772- interpolation_builder (f, u, x, args)  =   f (u , x, args)
813+ Base . nameof ( :: CachedInterpolation )  =   :CachedInterpolation 
814+ 
815+ @register_symbolic  (f :: CachedInterpolation )(u :: AbstractArray , x:: AbstractArray , args:: Tuple )
773816
774817""" 
775818    ParametrizedInterpolation(interp_type, u, x, args...; name) 
@@ -792,24 +835,22 @@ such as `LinearInterpolation`, `ConstantInterpolation` or `CubicSpline`.
792835  - `output`: a [`RealOutput`](@ref) connector corresponding to the interpolated value 
793836""" 
794837function  ParametrizedInterpolation (
795-         interp_type:: T , u:: AbstractVector , x:: AbstractVector , args... ; name) where  {T}
838+         interp_type:: T , u:: AbstractVector , x:: AbstractVector , args... ; name) where  T
839+ 
796840    @parameters  data[1 : length (x)] =  u
797841    @parameters  ts[1 : length (x)] =  x
798-     @parameters  interpolation_type:: T = interp_type [tunable =  false ] interpolation_args:: Tuple = args [tunable =  false ]
799-     @parameters  interpolator:: interp_type 
800- 
801-     build_interpolation =  cached_interpolation (interp_type, u, x, args)
802-     @parameters  memoized_builder:: typeof (build_interpolation)= build_interpolation [tunable =  false ]
842+     @parameters  interpolation_type:: T = interp_type [tunable =  false ]
843+     build_interpolation =  CachedInterpolation (interp_type, u, x, args)
844+     @parameters  (interpolator:: interp_type )(.. ):: eltype (u)
803845
804846    @named  output =  RealOutput ()
805847
806-     eqs =  [output. u ~  apply_interpolation ( interpolator,  t)]
848+     eqs =  [output. u ~  interpolator ( t)]
807849
808850    ODESystem (eqs, t, [],
809-         [data, ts, interpolation_type, interpolation_args,  interpolator, memoized_builder ];
851+         [data, ts, interpolation_type, interpolator];
810852        parameter_dependencies =  [
811-             interpolator =>  interpolation_builder (
812-             memoized_builder, data, ts, interpolation_args)
853+             interpolator ~  build_interpolation (data, ts, args)
813854        ],
814855        systems =  [output],
815856        name)
0 commit comments