Skip to content

Commit e3d63c2

Browse files
committed
refactor: make cached_interpolation work with Duals
1 parent ac9b145 commit e3d63c2

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
12+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
1213
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1314

1415
[weakdeps]
@@ -27,6 +28,7 @@ IfElse = "0.1"
2728
LinearAlgebra = "1.10"
2829
ModelingToolkit = "9.28"
2930
OrdinaryDiffEq = "6.87"
31+
PreallocationTools = "0.4.23"
3032
SafeTestsets = "0.1"
3133
Symbolics = "5.35.1, 6"
3234
Test = "1"

src/Blocks/sources.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DiffEqBase
22
import ChainRulesCore
3+
using PreallocationTools
34

45
# Define and register smooth functions
56
# These are "smooth" aka differentiable and avoid Gibbs effect
@@ -739,22 +740,28 @@ function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any
739740
end
740741

741742
function cached_interpolation(interpolation_type, u, x, args)
742-
prev_u = Ref(collect(copy(u)))
743-
prev_x = Ref(collect(copy(x)))
744-
# MTKParameters use views, so we want to ensure that the type is the same
745-
interp = Ref(interpolation_type(prev_u[], prev_x[], args...))
743+
prev_u = DiffCache(u)
744+
# Interpolation points can be a range, but we want to be able
745+
# to update the cache if needed (and setindex! is not defined on ranges)
746+
# with a view from MTKParameters, so we collect to get a vector
747+
prev_x = DiffCache(collect(x))
748+
interp = GeneralLazyBufferCache() do (u,x)
749+
interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
750+
end
746751

747752
let prev_u = prev_u,
748753
prev_x = prev_x,
749754
interp = interp,
750755
interpolation_type = interpolation_type
756+
751757
function build_interpolation(u, x, args)
752-
if (u, x) (prev_u[], prev_x[])
753-
prev_u[] = collect(u)
754-
prev_x[] = collect(x)
755-
interp[] = interpolation_type(prev_u[], prev_x[], args...)
758+
if (u, x) (get_tmp(prev_u, u), get_tmp(prev_x, x))
759+
get_tmp(prev_u, u) .= u
760+
get_tmp(prev_x, x) .= x
761+
interp.bufs[(u,x)] = interpolation_type(
762+
get_tmp(prev_u, u), get_tmp(prev_x, x), args...)
756763
else
757-
interp[]
764+
interp[(u,x)]
758765
end
759766
end
760767
end

0 commit comments

Comments
 (0)