|
1 | 1 | using DiffEqBase |
2 | 2 | import ChainRulesCore |
| 3 | +using PreallocationTools |
3 | 4 |
|
4 | 5 | # Define and register smooth functions |
5 | 6 | # These are "smooth" aka differentiable and avoid Gibbs effect |
@@ -739,22 +740,28 @@ function Symbolics.derivative(::typeof(apply_interpolation), args::NTuple{2, Any |
739 | 740 | end |
740 | 741 |
|
741 | 742 | function cached_interpolation(interpolation_type, u, x, args) |
742 | | - prev_u = Ref(collect(copy(u))) |
743 | | - prev_x = Ref(collect(copy(x))) |
744 | | - # MTKParameters use views, so we want to ensure that the type is the same |
745 | | - interp = Ref(interpolation_type(prev_u[], prev_x[], args...)) |
| 743 | + prev_u = DiffCache(u) |
| 744 | + # Interpolation points can be a range, but we want to be able |
| 745 | + # to update the cache if needed (and setindex! is not defined on ranges) |
| 746 | + # with a view from MTKParameters, so we collect to get a vector |
| 747 | + prev_x = DiffCache(collect(x)) |
| 748 | + interp = GeneralLazyBufferCache() do (u,x) |
| 749 | + interpolation_type(get_tmp(prev_u, u), get_tmp(prev_x, x), args...) |
| 750 | + end |
746 | 751 |
|
747 | 752 | let prev_u = prev_u, |
748 | 753 | prev_x = prev_x, |
749 | 754 | interp = interp, |
750 | 755 | interpolation_type = interpolation_type |
| 756 | + |
751 | 757 | function build_interpolation(u, x, args) |
752 | | - if (u, x) ≠ (prev_u[], prev_x[]) |
753 | | - prev_u[] = collect(u) |
754 | | - prev_x[] = collect(x) |
755 | | - interp[] = interpolation_type(prev_u[], prev_x[], args...) |
| 758 | + if (u, x) ≠ (get_tmp(prev_u, u), get_tmp(prev_x, x)) |
| 759 | + get_tmp(prev_u, u) .= u |
| 760 | + get_tmp(prev_x, x) .= x |
| 761 | + interp.bufs[(u,x)] = interpolation_type( |
| 762 | + get_tmp(prev_u, u), get_tmp(prev_x, x), args...) |
756 | 763 | else |
757 | | - interp[] |
| 764 | + interp[(u,x)] |
758 | 765 | end |
759 | 766 | end |
760 | 767 | end |
|
0 commit comments