11module NewtonKrylov
22
33export newton_krylov, newton_krylov!
4+ export halley_krylov, halley_krylov!
45
56using Krylov
67using LinearAlgebra, SparseArrays
@@ -84,6 +85,11 @@ function Base.collect(JOp::JacobianOperator)
8485 return J
8586end
8687
88+ """
89+ HessianOperator
90+
91+ Calculcates H(F, u) * v * v
92+ """
8793struct HessianOperator{F, A}
8894 J:: JacobianOperator{F, A}
8995 J_cache:: JacobianOperator{F, A}
@@ -95,12 +101,15 @@ Base.eltype(H::HessianOperator) = eltype(H.J)
95101
96102function mul! (out, H:: HessianOperator , v)
97103 _out = similar (H. J. res) # TODO cache in H
98- du = Enzyme. make_zero (H. J. u) # TODO cache in H
99-
100- autodiff (Forward, mul!,
101- DuplicatedNoNeed (_out, out),
102- DuplicatedNoNeed (H. J, H. J_cache),
103- DuplicatedNoNeed (du, v))
104+ Enzyme. make_zero! (H. J_cache)
105+ H. J_cache. u .= v
106+ autodiff (
107+ Forward,
108+ mul!,
109+ DuplicatedNoNeed (_out, out),
110+ DuplicatedNoNeed (H. J, H. J_cache),
111+ Const (v)
112+ )
104113
105114 return nothing
106115end
@@ -247,11 +256,102 @@ function newton_krylov!(
247256 η = forcing (η, tol, n_res, n_res_prior)
248257 end
249258
250- verbose > 0 && @info " Newton" iter = n_res η= (forcing != = nothing ? η : nothing ) stats
259+ verbose > 0 && @info " Newton" iter = n_res η = (forcing != = nothing ? η : nothing ) stats
251260 stats = update (stats, solver. stats. niter)
252261 end
253262 t = (time_ns () - t₀) / 1.0e9
254263 return u, (; solved = n_res <= tol, stats, t)
255264end
256265
266+ function halley_krylov (F, u₀:: AbstractArray , M:: Int = length (u₀); kwargs... )
267+ F! (res, u) = (res .= F (u); nothing )
268+ return halley_krylov! (F!, u₀, M; kwargs... )
269+ end
270+
271+ function halley_krylov! (F!, u₀:: AbstractArray , M:: Int = length (u₀); kwargs... )
272+ res = similar (u₀, M)
273+ return halley_krylov! (F!, u₀, res; kwargs... )
274+ end
275+
276+ function halley_krylov! (
277+ F!, u:: AbstractArray , res:: AbstractArray ;
278+ tol_rel = 1.0e-6 ,
279+ tol_abs = 1.0e-12 ,
280+ max_niter = 50 ,
281+ forcing:: Union{Forcing, Nothing} = EisenstatWalker (),
282+ verbose = 0 ,
283+ Solver = GmresSolver,
284+ M = nothing ,
285+ N = nothing ,
286+ krylov_kwargs = (;),
287+ callback = (args... ) -> nothing ,
288+ )
289+ t₀ = time_ns ()
290+ F! (res, u) # res = F(u)
291+ n_res = norm (res)
292+ callback (u, res, n_res)
293+
294+ tol = tol_rel * n_res + tol_abs
295+
296+ if forcing != = nothing
297+ η = inital (forcing)
298+ end
299+
300+ verbose > 0 && @info " Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η
301+
302+ J = JacobianOperator (F!, res, u)
303+ H = HessianOperator (J)
304+ solver = Solver (J, res)
305+
306+ stats = Stats (0 , 0 )
307+ while n_res > tol && stats. outer_iterations <= max_niter
308+ # Handle kwargs for Preconditoners
309+ kwargs = krylov_kwargs
310+ if N != = nothing
311+ kwargs = (; N = N (J), kwargs... )
312+ end
313+ if M != = nothing
314+ kwargs = (; M = M (J), kwargs... )
315+ end
316+ if forcing != = nothing
317+ # ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
318+ kwargs = (; rtol = η, kwargs... )
319+ end
320+
321+ solve! (solver, J, copy (res); kwargs... ) # J \ fx
322+ a = copy (solver. x)
323+
324+ # calculate hvvp (2nd order directional derivative using the JVP)
325+ hvvp = similar (res)
326+ mul! (hvvp, H, a)
327+
328+ solve! (solver, J, hvvp; kwargs... ) # J \ hvvp
329+ b = solver. x
330+
331+ # Update u
332+ @. u -= (a * a) / (a - b / 2 )
333+
334+ # Update residual and norm
335+ n_res_prior = n_res
336+
337+ F! (res, u) # res = F(u)
338+ n_res = norm (res)
339+ callback (u, res, n_res)
340+
341+ if isinf (n_res) || isnan (n_res)
342+ @error " Inner solver blew up" stats
343+ break
344+ end
345+
346+ if forcing != = nothing
347+ η = forcing (η, tol, n_res, n_res_prior)
348+ end
349+
350+ verbose > 0 && @info " Newton" iter = n_res η = (forcing != = nothing ? η : nothing ) stats
351+ stats = update (stats, solver. stats. niter) # TODO we do two calls to solver iterations
352+ end
353+ t = (time_ns () - t₀) / 1.0e9
354+ return u, (; solved = n_res <= tol, stats, t)
355+ end
356+
257357end # module NewtonKrylov
0 commit comments