Skip to content

Commit 951df5d

Browse files
Merge pull request #2437 from jClugstor/test_sol_stripping
add strip_solution tests
2 parents e3af641 + 8c2dd01 commit 951df5d

File tree

8 files changed

+31
-10
lines changed

8 files changed

+31
-10
lines changed

lib/OrdinaryDiffEqCore/src/alg_utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ end
191191
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
192192

193193
function alg_autodiff end
194-
has_lazy_interpolation(alg) = false
195194

196195
# Linear Exponential doesn't have any of the AD stuff
197196
function DiffEqBase.prepare_alg(

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function DiffEqBase.reeval_internals_due_to_modification!(
4545
if continuous_modification && integrator.opts.calck
4646
resize!(integrator.k, integrator.kshortsize) # Reset k for next step!
4747
alg = unwrap_alg(integrator, false)
48-
if has_lazy_interpolation(alg)
48+
if SciMLBase.has_lazy_interpolation(alg)
4949
ode_addsteps!(integrator, integrator.f, true, false, !alg.lazy)
5050
else
5151
ode_addsteps!(integrator, integrator.f, true, false)

lib/OrdinaryDiffEqCore/src/interp_func.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function SciMLBase.strip_interpolation(id::InterpolationData)
7575
end
7676

7777
function strip_cache(cache)
78-
if hasfield(typeof(cache), :jac_config) || hasfield(typeof(cache), :grad_config)
78+
if hasfield(typeof(cache), :jac_config) || hasfield(typeof(cache), :grad_config) || hasfield(typeof(cache), :nlsolver)
7979
fieldnums = length(fieldnames(typeof(cache)))
8080
noth_list = fill(nothing, fieldnums)
8181
cache_type_name = Base.typename(typeof(cache)).wrapper
@@ -84,3 +84,4 @@ function strip_cache(cache)
8484
cache
8585
end
8686
end
87+

lib/OrdinaryDiffEqLowOrderRK/src/OrdinaryDiffEqLowOrderRK.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import OrdinaryDiffEqCore: alg_order, isfsal, beta2_default, beta1_default,
1717
copyat_or_push!,
1818
AutoAlgSwitch, _ode_interpolant, _ode_interpolant!, full_cache,
1919
accept_step_controller, DerivativeOrderNotPossibleError,
20-
has_lazy_interpolation, du_cache, u_cache, get_fsalfirstlast
20+
du_cache, u_cache, get_fsalfirstlast
2121
using SciMLBase
2222
import MuladdMacro: @muladd
2323
import FastBroadcast: @..

lib/OrdinaryDiffEqLowOrderRK/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ alg_order(alg::Alshina2) = 2
2525
alg_order(alg::Alshina3) = 3
2626
alg_order(alg::Alshina6) = 6
2727

28-
has_lazy_interpolation(alg::BS5) = true
28+
SciMLBase.has_lazy_interpolation(alg::BS5) = true
2929

3030
isfsal(alg::FRK65) = true
3131
isfsal(alg::RKO65) = false

lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
1313
_ode_interpolant!, _ode_addsteps!, @fold,
1414
@OnDemandTableauExtract, AutoAlgSwitch,
1515
DerivativeOrderNotPossibleError,
16-
has_lazy_interpolation, get_fsalfirstlast
16+
get_fsalfirstlast
1717
using FastBroadcast, Polyester, MuladdMacro, RecursiveArrayTools
1818
using DiffEqBase: @def, @tight_loop_macros
1919
using Static: False

lib/OrdinaryDiffEqVerner/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ alg_stability_size(alg::Vern7) = 4.6400
1212
alg_stability_size(alg::Vern8) = 5.8641
1313
alg_stability_size(alg::Vern9) = 4.4762
1414

15-
has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9}) = true
15+
SciMLBase.has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9}) = true

test/interface/ode_strip_test.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,28 @@ u0 = [1.0; 0.0; 0.0]
1111
tspan = (0.0, 0.5)
1212
prob = ODEProblem(lorenz!, u0, tspan)
1313

14-
sol = solve(prob, Rosenbrock23())
14+
rosenbrock_sol = solve(prob, Rosenbrock23())
15+
TRBDF_sol = solve(prob, TRBDF2())
16+
vern_sol = solve(prob, Vern7())
17+
@testset "Interpolation Stripping" begin
18+
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).f)
19+
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).cache.jac_config)
20+
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).cache.grad_config)
21+
end
22+
23+
@testset "Rosenbrock Solution Stripping" begin
24+
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).prob)
25+
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).alg)
26+
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.f)
27+
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.jac_config)
28+
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.grad_config)
29+
end
30+
31+
@testset "TRBDF Solution Stripping" begin
32+
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).prob)
33+
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).alg)
34+
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.f)
35+
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.cache.nlsolver)
36+
end
1537

16-
@test isnothing(SciMLBase.strip_interpolation(sol.interp).f)
17-
@test isnothing(SciMLBase.strip_interpolation(sol.interp).cache.jac_config)
38+
@test_throws SciMLBase.LazyInterpolationException SciMLBase.strip_solution(vern_sol)

0 commit comments

Comments
 (0)