@@ -2,10 +2,6 @@ module SimpleImplicitDiscreteSolve
22
33using SciMLBase
44using SimpleNonlinearSolve
5- using UnPack
6- using SymbolicIndexingInterface: parameter_symbols
7- import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, alg_cache, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, get_fsalfirstlast, isfsal, initialize!, perform_step!, isdiscretecache, isdiscretealg, alg_order, beta2_default, beta1_default, dt_required, _initialize_dae!, DefaultInit, BrownFullBasicInit, OverrideInit
8-
95using Reexport
106@reexport using DiffEqBase
117
@@ -14,11 +10,85 @@ using Reexport
1410
1511Simple solver for `ImplicitDiscreteSystems`. Uses `SimpleNewtonRaphson` to solve for the next state at every timestep.
1612"""
17- struct SimpleIDSolve <: OrdinaryDiffEqAlgorithm end
13+ struct SimpleIDSolve <: SciMLBase.AbstractODEAlgorithm end
14+
15+ function DiffEqBase. __init (prob:: ImplicitDiscreteProblem , alg:: SimpleIDSolve ; dt = 1 )
16+ u0 = prob. u0
17+ p = prob. p
18+ f = prob. f
19+ t = prob. tspan[1 ]
20+
21+ nlf = isinplace (f) ? (out, u, p) -> f (out, u, u0, p, t) : (u, p) -> f (u, u0, p, t)
22+ prob = NonlinearProblem {isinplace(f)} (nlf, u0, p)
23+ sol = solve (prob, SimpleNewtonRaphson ())
24+ sol, (sol. retcode != ReturnCode. Success)
25+ end
26+
27+ function DiffEqBase. solve (prob:: ImplicitDiscreteProblem , alg:: SimpleIDSolve ;
28+ dt = 1 ,
29+ save_everystep = true ,
30+ save_start = true ,
31+ adaptive = false ,
32+ dense = false ,
33+ save_end = true ,
34+ kwargs... )
35+
36+ @assert ! adaptive
37+ @assert ! dense
38+ (initsol, initfail) = DiffEqBase. __init (prob, alg; dt)
39+ if initfail
40+ sol = DiffEqBase. build_solution (prob, alg, prob. tspan[1 ], u0, k = nothing , stats = nothing , calculate_error = false )
41+ return SciMLBase. solution_new_retcode (sol, ReturnCode. InitialFailure)
42+ end
1843
19- include (" cache.jl" )
20- include (" solve.jl" )
21- include (" alg_utils.jl" )
44+ u0 = initsol. u
45+ tspan = prob. tspan
46+ f = prob. f
47+ p = prob. p
48+ t = tspan[1 ]
49+ tf = prob. tspan[2 ]
50+ ts = tspan[1 ]: dt: tspan[2 ]
51+
52+ if save_everystep && save_start
53+ us = Vector {typeof(u0)} (undef, length (ts))
54+ us[1 ] = u0
55+ elseif save_everystep
56+ us = Vector {typeof(u0)} (undef, length (ts) - 1 )
57+ elseif save_start
58+ us = Vector {typeof(u0)} (undef, 2 )
59+ us[1 ] = u0
60+ else
61+ us = Vector {typeof(u0)} (undef, 1 ) # for interface compatibility
62+ end
63+
64+ u = u0
65+ convfail = false
66+ for i in 2 : length (ts)
67+ uprev = u
68+ t = ts[i]
69+ nlf = isinplace (f) ? (out, u, p) -> f (out, u, uprev, p, t) : (u, p) -> f (u, uprev, p, t)
70+ nlprob = NonlinearProblem {isinplace(f)} (nlf, uprev, p)
71+ nlsol = solve (nlprob, SimpleNewtonRaphson ())
72+ u = nlsol. u
73+ save_everystep && (us[i] = u)
74+ convfail = (nlsol. retcode != ReturnCode. Success)
75+
76+ if convfail
77+ sol = DiffEqBase. build_solution (prob, alg, ts[1 : i], us[1 : i], k = nothing , stats = nothing , calculate_error = false )
78+ sol = SciMLBase. solution_new_retcode (sol, ReturnCode. ConvergenceFailure)
79+ return sol
80+ end
81+ end
82+
83+ ! save_everystep && save_end && (us[end ] = u)
84+ sol = DiffEqBase. build_solution (prob, alg, ts, us,
85+ k = nothing , stats = nothing ,
86+ calculate_error = false )
87+
88+ DiffEqBase. has_analytic (prob. f) &&
89+ DiffEqBase. calculate_solution_errors! (sol; timeseries_errors = true , dense_errors = false )
90+ sol
91+ end
2292
2393export SimpleIDSolve
2494
0 commit comments