@@ -19,33 +19,54 @@ solves.
1919 - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
2020 which means that no line search is performed. Algorithms from `LineSearches.jl` can be
2121 used here directly, and they will be converted to the correct `LineSearch`.
22+ - `init_jacobian`: the method to use for initializing the jacobian. Defaults to using the
23+ identity matrix (`Val(:identitiy)`). Alternatively, can be set to `Val(:true_jacobian)`
24+ to use the true jacobian as initialization. (Our tests suggest it is a good idea to
25+ to initialize with an identity matrix)
26+ - `autodiff`: determines the backend used for the Jacobian. Note that this argument is
27+ ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
28+ `nothing` which means that a default is selected according to the problem specification!
29+ Valid choices are types from ADTypes.jl. (Used if `init_jacobian = Val(:true_jacobian)`)
2230"""
23- @concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
31+ @concrete struct GeneralKlement{IJ, CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
32+ ad:: AD
2433 max_resets:: Int
2534 linsolve
2635 precs
2736 linesearch
2837end
2938
30- function set_linsolve (alg:: GeneralKlement , linsolve)
31- return GeneralKlement (alg. max_resets, linsolve, alg. precs, alg. linesearch)
39+ function set_linsolve (alg:: GeneralKlement{IJ, CS} , linsolve) where {IJ, CS}
40+ return GeneralKlement {IJ, CS} (alg. ad, alg. max_resets, linsolve, alg. precs,
41+ alg. linesearch)
42+ end
43+
44+ function set_ad (alg:: GeneralKlement{IJ, CS} , ad) where {IJ, CS}
45+ return GeneralKlement {IJ, CS} (ad, alg. max_resets, alg. linsolve, alg. precs,
46+ alg. linesearch)
3247end
3348
3449function GeneralKlement (; max_resets:: Int = 5 , linsolve = nothing ,
35- linesearch = nothing , precs = DEFAULT_PRECS)
50+ linesearch = nothing , precs = DEFAULT_PRECS, init_jacobian:: Val = Val (:identity ),
51+ autodiff = nothing )
52+ IJ = _unwrap_val (init_jacobian)
53+ @assert IJ ∈ (:identity , :true_jacobian )
3654 linesearch = linesearch isa LineSearch ? linesearch : LineSearch (; method = linesearch)
37- return GeneralKlement (max_resets, linsolve, precs, linesearch)
55+ CJ = IJ === :true_jacobian
56+ return GeneralKlement {IJ, CJ} (autodiff, max_resets, linsolve, precs, linesearch)
3857end
3958
40- @concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
59+ @concrete mutable struct GeneralKlementCache{iip, IJ } <: AbstractNonlinearSolveCache{iip}
4160 f
4261 alg
4362 u
4463 u_cache
4564 fu
4665 fu_cache
66+ fu_cache_2
4767 du
4868 p
69+ uf
4970 linsolve
5071 J
5172 J_cache
6081 abstol
6182 reltol
6283 prob
84+ jac_cache
6385 stats:: NLStats
6486 ls_cache
6587 tc_cache
6688 trace
6789end
6890
69- function SciMLBase. __init (prob:: NonlinearProblem{uType, iip} , alg_:: GeneralKlement , args ... ;
70- alias_u0 = false , maxiters = 1000 , abstol = nothing , reltol = nothing ,
91+ function SciMLBase. __init (prob:: NonlinearProblem{uType, iip} , alg_:: GeneralKlement{IJ} ,
92+ args ... ; alias_u0 = false , maxiters = 1000 , abstol = nothing , reltol = nothing ,
7193 termination_condition = nothing , internalnorm:: F = DEFAULT_NORM,
72- linsolve_kwargs = (;), kwargs... ) where {uType, iip, F}
94+ linsolve_kwargs = (;), kwargs... ) where {uType, iip, F, IJ }
7395 @unpack f, u0, p = prob
7496 u = __maybe_unaliased (u0, alias_u0)
7597 fu = evaluate_f (prob, u)
76- J = __init_identity_jacobian (u, fu)
77- @bb du = similar (u)
7898
79- if u isa Number
80- linsolve = FakeLinearSolveJLCache (J, fu)
99+ if IJ === :true_jacobian
100+ alg = get_concrete_algorithm (alg_, prob)
101+ uf, _, J, fu_cache, jac_cache, du = jacobian_caches (alg, f, u, p, Val (iip);
102+ lininit = Val (false ))
103+ elseif IJ === :identity
81104 alg = alg_
105+ @bb du = similar (u)
106+ uf, fu_cache, jac_cache = nothing , nothing , nothing
107+ J = one .(u) # Identity Init Jacobian for Klement maintains a Diagonal Structure
108+ else
109+ error (" Invalid `init_jacobian` value" )
110+ end
111+
112+ if IJ === :true_jacobian
113+ linsolve = linsolve_caches (J, _vec (fu), _vec (du), p, alg_; linsolve_kwargs)
82114 else
83- # For General Julia Arrays default to LU Factorization
84- linsolve_alg = (alg_. linsolve === nothing && (u isa Array || u isa StaticArray)) ?
85- LUFactorization () : nothing
86- alg = set_linsolve (alg_, linsolve_alg)
87- linsolve = linsolve_caches (J, _vec (fu), _vec (du), p, alg; linsolve_kwargs)
115+ linsolve = nothing
88116 end
89117
90118 abstol, reltol, tc_cache = init_termination_cache (abstol, reltol, fu, u,
91119 termination_condition)
92120 trace = init_nonlinearsolve_trace (alg, u, fu, J, du; kwargs... )
93121
94122 @bb u_cache = copy (u)
95- @bb fu_cache = copy (fu)
96- @bb J_cache = similar (J)
97- @bb J_cache_2 = similar (J)
123+ @bb fu_cache_2 = copy (fu)
98124 @bb Jdu = similar (fu)
99- @bb Jdu_cache = similar (fu)
125+ if IJ === :true_jacobian
126+ @bb J_cache = similar (J)
127+ @bb J_cache_2 = similar (J)
128+ @bb Jdu_cache = similar (fu)
129+ else
130+ J_cache, J_cache_2, Jdu_cache = nothing , nothing , nothing
131+ end
100132
101- return GeneralKlementCache {iip} (f, alg, u, u_cache, fu, fu_cache, du, p, linsolve,
102- J, J_cache, J_cache_2, Jdu, Jdu_cache, 0 , false , maxiters, internalnorm,
103- ReturnCode. Default, abstol, reltol, prob, NLStats (1 , 0 , 0 , 0 , 0 ),
133+ return GeneralKlementCache {iip, IJ} (f, alg, u, u_cache, fu, fu_cache, fu_cache_2, du, p,
134+ uf, linsolve, J, J_cache, J_cache_2, Jdu, Jdu_cache, 0 , false , maxiters,
135+ internalnorm,
136+ ReturnCode. Default, abstol, reltol, prob, jac_cache, NLStats (1 , 0 , 0 , 0 , 0 ),
104137 init_linesearch_cache (alg. linesearch, f, u, p, fu, Val (iip)), tc_cache, trace)
105138end
106139
107- function perform_step! (cache:: GeneralKlementCache{iip} ) where {iip}
140+ function perform_step! (cache:: GeneralKlementCache{iip, IJ } ) where {iip, IJ }
108141 @unpack linsolve, alg = cache
109142 T = eltype (cache. J)
110- singular, fact_done = __try_factorize_and_check_singular! (linsolve, cache. J)
111143
112- if singular
144+ if IJ === :true_jacobian
145+ cache. stats. nsteps == 0 && (cache. J = jacobian!! (cache. J, cache))
146+ ill_conditioned = __is_ill_conditioned (cache. J)
147+ elseif IJ === :identity
148+ ill_conditioned = __is_ill_conditioned (cache. J)
149+ end
150+
151+ if ill_conditioned
113152 if cache. resets == alg. max_resets
114153 cache. force_stop = true
115154 cache. retcode = ReturnCode. ConvergenceFailure
116155 return nothing
117156 end
118- fact_done = false
119- cache. J = __reinit_identity_jacobian!! (cache. J)
157+ if IJ === :true_jacobian && cache. stats. nsteps != 0
158+ cache. J = jacobian!! (cache. J, cache)
159+ else
160+ cache. J = __reinit_identity_jacobian!! (cache. J)
161+ end
120162 cache. resets += 1
121163 end
122164
123- A = ifelse (cache. J isa SMatrix || cache. J isa Number || ! fact_done, cache. J, nothing )
124-
125- # u = u - J \ fu
126- linres = dolinsolve (cache, alg. precs, cache. linsolve; A,
127- b = _vec (cache. fu), linu = _vec (cache. du), cache. p, reltol = cache. abstol)
128- cache. linsolve = linres. cache
129- cache. du = _restructure (cache. du, linres. u)
165+ if IJ === :identity
166+ @bb @. cache. du = cache. fu / cache. J
167+ else
168+ # u = u - J \ fu
169+ linres = dolinsolve (cache, alg. precs, cache. linsolve; A = cache. J,
170+ b = _vec (cache. fu), linu = _vec (cache. du), cache. p, reltol = cache. abstol)
171+ cache. linsolve = linres. cache
172+ cache. du = _restructure (cache. du, linres. u)
173+ end
130174
131175 # Line Search
132176 α = perform_linesearch! (cache. ls_cache, cache. u, cache. du)
@@ -143,18 +187,25 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
143187
144188 # Update the Jacobian
145189 @bb cache. du .*= - 1
146- @bb cache. J_cache .= cache. J' .^ 2
147- @bb @. cache. Jdu = cache. du^ 2
148- @bb cache. Jdu_cache = cache. J_cache × vec (cache. Jdu)
149- @bb cache. Jdu = cache. J × vec (cache. du)
150- @bb @. cache. fu_cache = (cache. fu - cache. fu_cache - cache. Jdu) /
151- ifelse (iszero (cache. Jdu_cache), T (1e-5 ), cache. Jdu_cache)
152- @bb cache. J_cache = vec (cache. fu_cache) × transpose (_vec (cache. du))
153- @bb @. cache. J_cache *= cache. J
154- @bb cache. J_cache_2 = cache. J_cache × cache. J
155- @bb cache. J .+ = cache. J_cache_2
156-
157- @bb copyto! (cache. fu_cache, cache. fu)
190+ if IJ === :identity
191+ @bb @. cache. Jdu = (cache. J^ 2 ) * (cache. du^ 2 )
192+ @bb @. cache. J += ((cache. fu - cache. fu_cache_2 - cache. J * cache. du) /
193+ ifelse (iszero (cache. Jdu), T (1e-5 ), cache. Jdu)) * cache. du *
194+ (cache. J^ 2 )
195+ else
196+ @bb cache. J_cache .= cache. J' .^ 2
197+ @bb @. cache. Jdu = cache. du^ 2
198+ @bb cache. Jdu_cache = cache. J_cache × vec (cache. Jdu)
199+ @bb cache. Jdu = cache. J × vec (cache. du)
200+ @bb @. cache. fu_cache_2 = (cache. fu - cache. fu_cache_2 - cache. Jdu) /
201+ ifelse (iszero (cache. Jdu_cache), T (1e-5 ), cache. Jdu_cache)
202+ @bb cache. J_cache = vec (cache. fu_cache_2) × transpose (_vec (cache. du))
203+ @bb @. cache. J_cache *= cache. J
204+ @bb cache. J_cache_2 = cache. J_cache × cache. J
205+ @bb cache. J .+ = cache. J_cache_2
206+ end
207+
208+ @bb copyto! (cache. fu_cache_2, cache. fu)
158209
159210 return nothing
160211end
0 commit comments