Skip to content

Commit 7d61a05

Browse files
committed
refactor: make use of callable parameters to simplify the implementation
1 parent 1c5a1af commit 7d61a05

File tree

1 file changed

+81
-40
lines changed

1 file changed

+81
-40
lines changed

src/Blocks/sources.jl

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -733,43 +733,86 @@ function SampledData(; name, buffer, sample_time, circular_buffer)
733733
end
734734

735735
# This needs to be extend for interpolation types
736-
apply_interpolation(interp, t) = interp(t)
736+
# apply_interpolation(interp, t) = interp(t)
737+
738+
# function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any}, ::Val{2})
739+
# Symbolics.derivative(args[1], (args[2],), Val(1))
740+
# end
741+
742+
# function cached_interpolation(interpolation_type, u, x, args)
743+
# prev_u = DiffCache(u)
744+
# # Interpolation points can be a range, but we want to be able
745+
# # to update the cache if needed (and setindex! is not defined on ranges)
746+
# # with a view from MTKParameters, so we collect to get a vector
747+
# prev_x = DiffCache(collect(x))
748+
# interp = GeneralLazyBufferCache() do (u, x)
749+
# interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
750+
# end
751+
752+
# let prev_u = prev_u,
753+
# prev_x = prev_x,
754+
# interp = interp,
755+
# interpolation_type = interpolation_type
756+
757+
# function build_interpolation(u, x, args)
758+
# if (u, x) ≠ (get_tmp(prev_u, u), get_tmp(prev_x, x))
759+
# get_tmp(prev_u, u) .= u
760+
# get_tmp(prev_x, x) .= x
761+
# interp.bufs[(u, x)] = interpolation_type(
762+
# get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
763+
# else
764+
# interp[(u, x)]
765+
# end
766+
# end
767+
# end
768+
# end
769+
770+
struct CachedInterpolation{T,I,U,X,C}
771+
interpolation_type::I
772+
prev_u::U
773+
prev_x::X
774+
cache::C
775+
776+
function CachedInterpolation(interpolation_type, u, x, args)
777+
# we need to copy the inputs to avoid aliasing
778+
prev_u = DiffCache(copy(u))
779+
# Interpolation points can be a range, but we want to be able
780+
# to update the cache if needed (and setindex! is not defined on ranges)
781+
# with a view from MTKParameters, so we collect to get a vector
782+
prev_x = DiffCache(collect(copy(x)))
783+
cache = GeneralLazyBufferCache() do (u, x)
784+
interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
785+
end
786+
T = typeof(cache[(get_tmp(prev_u, u), get_tmp(prev_x, x))])
787+
I = typeof(interpolation_type)
788+
U = typeof(prev_u)
789+
X = typeof(prev_x)
790+
C = typeof(cache)
737791

738-
function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any}, ::Val{2})
739-
Symbolics.derivative(args[1], (args[2],), Val(1))
792+
new{T,I,U,X,C}(interpolation_type, prev_u, prev_x, cache)
793+
end
740794
end
741795

742-
function cached_interpolation(interpolation_type, u, x, args)
743-
prev_u = DiffCache(u)
744-
# Interpolation points can be a range, but we want to be able
745-
# to update the cache if needed (and setindex! is not defined on ranges)
746-
# with a view from MTKParameters, so we collect to get a vector
747-
prev_x = DiffCache(collect(x))
748-
interp = GeneralLazyBufferCache() do (u, x)
749-
interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
750-
end
796+
function (f::CachedInterpolation{T})(u, x, args) where T
797+
(;prev_u, prev_x, cache, interpolation_type) = f
751798

752-
let prev_u = prev_u,
753-
prev_x = prev_x,
754-
interp = interp,
755-
interpolation_type = interpolation_type
756-
757-
function build_interpolation(u, x, args)
758-
if (u, x) (get_tmp(prev_u, u), get_tmp(prev_x, x))
759-
get_tmp(prev_u, u) .= u
760-
get_tmp(prev_x, x) .= x
761-
interp.bufs[(u, x)] = interpolation_type(
762-
get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
763-
else
764-
interp[(u, x)]
765-
end
766-
end
799+
interp = @inbounds if (u, x) (get_tmp(prev_u, u), get_tmp(prev_x, x))
800+
get_tmp(prev_u, u) .= u
801+
get_tmp(prev_x, x) .= x
802+
# @info "cache miss"
803+
cache.bufs[(u, x)] = interpolation_type(
804+
get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
805+
else
806+
# @info "cache hit"
807+
cache[(u, x)]
767808
end
809+
810+
return interp
768811
end
769812

770-
@register_symbolic interpolation_builder(
771-
f::Function, u::AbstractArray, x::AbstractArray, args::Tuple)
772-
interpolation_builder(f, u, x, args) = f(u, x, args)
813+
Base.nameof(::CachedInterpolation) = :CachedInterpolation
814+
815+
@register_symbolic (f::CachedInterpolation)(u::AbstractArray, x::AbstractArray, args::Tuple)
773816

774817
"""
775818
ParametrizedInterpolation(interp_type, u, x, args...; name)
@@ -792,24 +835,22 @@ such as `LinearInterpolation`, `ConstantInterpolation` or `CubicSpline`.
792835
- `output`: a [`RealOutput`](@ref) connector corresponding to the interpolated value
793836
"""
794837
function ParametrizedInterpolation(
795-
interp_type::T, u::AbstractVector, x::AbstractVector, args...; name) where {T}
838+
interp_type::T, u::AbstractVector, x::AbstractVector, args...; name) where T
839+
796840
@parameters data[1:length(x)] = u
797841
@parameters ts[1:length(x)] = x
798-
@parameters interpolation_type::T=interp_type [tunable = false] interpolation_args::Tuple=args [tunable = false]
799-
@parameters interpolator::interp_type
800-
801-
build_interpolation = cached_interpolation(interp_type, u, x, args)
802-
@parameters memoized_builder::typeof(build_interpolation)=build_interpolation [tunable = false]
842+
@parameters interpolation_type::T=interp_type [tunable = false]
843+
build_interpolation = CachedInterpolation(interp_type, u, x, args)
844+
@parameters (interpolator::interp_type)(..)::eltype(u)
803845

804846
@named output = RealOutput()
805847

806-
eqs = [output.u ~ apply_interpolation(interpolator, t)]
848+
eqs = [output.u ~ interpolator(t)]
807849

808850
ODESystem(eqs, t, [],
809-
[data, ts, interpolation_type, interpolation_args, interpolator, memoized_builder];
851+
[data, ts, interpolation_type, interpolator];
810852
parameter_dependencies = [
811-
interpolator => interpolation_builder(
812-
memoized_builder, data, ts, interpolation_args)
853+
interpolator ~ build_interpolation(data, ts, args)
813854
],
814855
systems = [output],
815856
name)

0 commit comments

Comments
 (0)