11module NewtonKrylov
22
33export newton_krylov, newton_krylov!
4+ export halley_krylov, halley_krylov!
45
56using Krylov
67using LinearAlgebra, SparseArrays
@@ -89,6 +90,11 @@ function Base.collect(JOp::JacobianOperator)
8990 return J
9091end
9192
"""
    HessianOperator

Calculates `H(F, u) * v * v`, the second-order directional derivative of `F` at `u`
along `v`.
"""
9298struct HessianOperator{F, A}
9399 J:: JacobianOperator{F, A}
94100 J_cache:: JacobianOperator{F, A}
@@ -100,12 +106,15 @@ Base.eltype(H::HessianOperator) = eltype(H.J)
100106
"""
    mul!(out, H::HessianOperator, v)

Write the second-order directional derivative `H(F, u) * v * v` into `out`.

The Jacobian-vector product `mul!(_, H.J, v)` is differentiated once more in forward
mode. `H.J_cache` serves as the tangent (shadow) of `H.J`: it is zeroed and its `u`
field is seeded with `v`, while `v` itself is passed as a constant. The primal output
goes to a scratch buffer; the tangent of the output is what ends up in `out`.

Mutates `out` and `H.J_cache`. Returns `nothing`.
"""
function mul!(out, H::HessianOperator, v)
    # Scratch buffer for the primal result of the inner JVP; only its tangent is wanted.
    _out = similar(H.J.res) # TODO cache in H

    # Reset the shadow operator, then seed the tangent direction of `u` with `v`.
    Enzyme.make_zero!(H.J_cache)
    H.J_cache.u .= v

    autodiff(
        Forward,
        mul!,
        DuplicatedNoNeed(_out, out),
        DuplicatedNoNeed(H.J, H.J_cache),
        Const(v)
    )

    return nothing
end
@@ -299,11 +308,102 @@ function newton_krylov!(
299308 η = forcing (η, tol, n_res, n_res_prior)
300309 end
301310
302- verbose > 0 && @info " Newton" iter = n_res η= (forcing != = nothing ? η : nothing ) stats
311+ verbose > 0 && @info " Newton" iter = n_res η = (forcing != = nothing ? η : nothing ) stats
303312 stats = update (stats, solver. stats. niter)
304313 end
305314 t = (time_ns () - t₀) / 1.0e9
306315 return u, (; solved = n_res <= tol, stats, t)
307316end
308317
"""
    halley_krylov(F, u₀, M = length(u₀); kwargs...)

Out-of-place entry point for the Halley-Krylov solver.

Wraps `F`, which returns the residual as `F(u)`, into an in-place function that
broadcasts the result into its first argument, then forwards to `halley_krylov!`
with `u₀`, `M` and all keyword arguments unchanged.
"""
function halley_krylov(F, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    # In-place adapter: store F(u) into `res` and return nothing.
    function F!(res, u)
        res .= F(u)
        return nothing
    end
    return halley_krylov!(F!, u₀, M; kwargs...)
end
322+
"""
    halley_krylov!(F!, u₀, M = length(u₀); kwargs...)

Allocate a residual buffer of length `M` with `similar(u₀, M)` and forward to the
`halley_krylov!(F!, u, res; kwargs...)` method.
"""
function halley_krylov!(F!, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    residual = similar(u₀, M)
    return halley_krylov!(F!, u₀, residual; kwargs...)
end
327+
"""
    halley_krylov!(F!, u, res; kwargs...)

Jacobian-free Halley-Krylov iteration for `F!(res, u)`, updating `u` and `res` in place.

Every outer iteration performs two Krylov solves with the same Jacobian operator `J`:
`a = J \\ res`, then `b = J \\ hvvp`, where `hvvp` is the second-order directional
derivative along `a` computed by `mul!(hvvp, H, a)`. The iterate is then updated
elementwise as `u -= (a * a) / (a - b / 2)`.

# Keywords
- `tol_rel`, `tol_abs`: the loop stops once `norm(res) <= tol_rel * norm(res₀) + tol_abs`,
  where `res₀` is the residual at the initial `u`.
- `max_niter`: the loop runs while `stats.outer_iterations <= max_niter`.
- `forcing`: forcing strategy, or `nothing`. When given, its current value `η` is passed
  to both Krylov solves as `rtol` and refreshed after each outer iteration.
- `verbose`: values greater than `0` enable `@info` logging.
- `Solver`: solver constructor, invoked as `Solver(J, res)`.
- `M`, `N`: optional callables; `M(J)` / `N(J)` are passed to `solve!` as the `M` / `N`
  keywords. Entries of the same name in `krylov_kwargs` take precedence.
- `krylov_kwargs`: additional keyword arguments forwarded to `solve!`.
- `callback`: invoked as `callback(u, res, n_res)` after every residual evaluation.

# Returns
`(u, (; solved, stats, t))`, where `solved` is `n_res <= tol`, `stats` is the accumulated
iteration statistics and `t` is the elapsed wall time in seconds.
"""
function halley_krylov!(
        F!, u::AbstractArray, res::AbstractArray;
        tol_rel = 1.0e-6,
        tol_abs = 1.0e-12,
        max_niter = 50,
        forcing::Union{Forcing, Nothing} = EisenstatWalker(),
        verbose = 0,
        Solver = GmresSolver,
        M = nothing,
        N = nothing,
        krylov_kwargs = (;),
        callback = (args...) -> nothing,
    )
    t₀ = time_ns()
    F!(res, u) # res = F(u)
    n_res = norm(res)
    callback(u, res, n_res)

    tol = tol_rel * n_res + tol_abs

    if forcing !== nothing
        η = inital(forcing)
    end

    if verbose > 0
        # η is only assigned when a forcing strategy is in use; log `nothing` otherwise.
        η₀ = forcing !== nothing ? η : nothing
        @info " Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η = η₀
    end

    J = JacobianOperator(F!, res, u)
    H = HessianOperator(J)
    solver = Solver(J, res)

    # Buffer for the second-order directional derivative, reused across iterations.
    hvvp = similar(res)

    stats = Stats(0, 0)
    while n_res > tol && stats.outer_iterations <= max_niter
        # Handle kwargs for preconditioners
        kwargs = krylov_kwargs
        if N !== nothing
            kwargs = (; N = N(J), kwargs...)
        end
        if M !== nothing
            kwargs = (; M = M(J), kwargs...)
        end
        if forcing !== nothing
            # ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
            kwargs = (; rtol = η, kwargs...)
        end

        # First solve: a = J \ res. The right-hand side is a copy of `res`.
        # NOTE(review): presumably the copy guards against `J` touching `res` — confirm.
        solve!(solver, J, copy(res); kwargs...)
        # The second solve overwrites `solver.x`, so keep a copy and the iteration count.
        a = copy(solver.x)
        niter_a = solver.stats.niter

        # calculate hvvp (2nd order directional derivative using the JVP)
        mul!(hvvp, H, a)

        # Second solve: b = J \ hvvp. `b` aliases `solver.x`.
        solve!(solver, J, hvvp; kwargs...)
        b = solver.x

        # Update u
        @. u -= (a * a) / (a - b / 2)

        # Update residual and norm
        n_res_prior = n_res

        F!(res, u) # res = F(u)
        n_res = norm(res)
        callback(u, res, n_res)

        if isinf(n_res) || isnan(n_res)
            @error " Inner solver blew up" stats
            break
        end

        if forcing !== nothing
            η = forcing(η, tol, n_res, n_res_prior)
        end

        if verbose > 0
            η_log = forcing !== nothing ? η : nothing
            @info " Halley" iter = n_res η = η_log stats
        end
        # Count the Krylov iterations of both solves performed in this outer iteration.
        stats = update(stats, niter_a + solver.stats.niter)
    end
    t = (time_ns() - t₀) / 1.0e9
    return u, (; solved = n_res <= tol, stats, t)
end
408+
309409end # module NewtonKrylov
0 commit comments