22
33using NewtonKrylov, LinearAlgebra
44using CairoMakie
5- using Krylov, Enzyme
65
76function F! (res, x)
87 res[1 ] = x[1 ]^ 2 + x[2 ]^ 2 - 2
@@ -15,68 +14,89 @@ function F(x)
1514 return res
1615end
1716
18- function halley (F!, u, res;
17+ import NewtonKrylov: Forcing, EisenstatWalker, inital, forcing, solve!,
18+ JacobianOperator, HessianOperator, Stats, update, GmresSolver
19+
20+ function halley_krylov! (
21+ F!, u:: AbstractArray , res:: AbstractArray ;
1922 tol_rel = 1.0e-6 ,
2023 tol_abs = 1.0e-12 ,
2124 max_niter = 50 ,
25+ forcing:: Union{Forcing, Nothing} = EisenstatWalker (),
26+ verbose = 0 ,
2227 Solver = GmresSolver,
28+ M = nothing ,
29+ N = nothing ,
30+ krylov_kwargs = (;),
31+ callback = (args... ) -> nothing ,
2332 )
24-
33+ t₀ = time_ns ()
2534 F! (res, u) # res = F(u)
2635 n_res = norm (res)
36+ callback (u, res, n_res)
2737
2838 tol = tol_rel * n_res + tol_abs
2939
30- J = NewtonKrylov. JacobianOperator (F!, res, u)
31- H = NewtonKrylov. HessianOperator (J)
40+ if forcing != = nothing
41+ η = inital (forcing)
42+ end
43+
44+ verbose > 0 && @info " Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η
45+
46+ J = JacobianOperator (F!, res, u)
47+ H = HessianOperator (J)
3248 solver = Solver (J, res)
3349
34- for i in :max_niter
35- if n_res <= tol
36- break
50+ stats = Stats (0 , 0 )
51+ while n_res > tol && stats. outer_iterations <= max_niter
52+ # Handle kwargs for Preconditoners
53+ kwargs = krylov_kwargs
54+ if N != = nothing
55+ kwargs = (; N = N (J), kwargs... )
3756 end
38- solve! (solver, J, copy (res)) # J \ fx
57+ if M != = nothing
58+ kwargs = (; M = M (J), kwargs... )
59+ end
60+ if forcing != = nothing
61+ # ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
62+ kwargs = (; rtol = η, kwargs... )
63+ end
64+
65+ solve! (solver, J, copy (res); kwargs... ) # J \ fx
3966 a = copy (solver. x)
4067
4168 # calculate hvvp (2nd order directional derivative using the JVP)
4269 hvvp = similar (res)
4370 mul! (hvvp, H, a)
4471
45- solve! (solver, J, hvvp) # J \ hvvp
72+ solve! (solver, J, hvvp; kwargs ... ) # J \ hvvp
4673 b = solver. x
4774
48- # update
75+ # Update u
4976 @. u -= (a * a) / (a - b / 2 )
5077
51- end
52- end
53-
54- # u = [2.0, 0.5]
55- # res = zeros(2)
56- # J = NewtonKrylov.JacobianOperator(F!,u,res)
57- # F!(res, u)
58- # a, stats = gmres(J, copy(res))
78+ # Update residual and norm
79+ n_res_prior = n_res
5980
60- # J_cache = Enzyme.make_zero(J)
61- # out = similar(J.res)
62- # hvvp = Enzyme.make_zero(out)
63- # du = Enzyme.make_zero(J.u)
64- # autodiff(Forward, LinearAlgebra.mul!,
65- # DuplicatedNoNeed(out, hvvp),
66- # DuplicatedNoNeed(J, J_cache),
67- # DuplicatedNoNeed(du, a))
81+ F! (res, u) # res = F(u)
82+ n_res = norm (res)
83+ callback (u, res, n_res)
6884
69- # hvvp
70-
71- # b, stats = gmres(J, hvvp)
72- # @. u -= (a * a) / (a - b / 2)
73-
74- # a
85+ if isinf (n_res) || isnan (n_res)
86+ @error " Inner solver blew up" stats
87+ break
88+ end
7589
90+ if forcing != = nothing
91+ η = forcing (η, tol, n_res, n_res_prior)
92+ end
7693
77- dg_ad (x, dx) = autodiff (Forward, flux, DuplicatedNoNeed (x, dx))[1 ]
78- ddg_ad (x, dx, ddx) = autodiff (Forward, dg_ad, DuplicatedNoNeed (x, dx),
79- DuplicatedNoNeed (dx, ddx))[1 ]
94+ verbose > 0 && @info " Newton" iter = n_res η= (forcing != = nothing ? η : nothing ) stats
95+ stats = update (stats, solver. stats. niter) # TODO we do two calls to solver iterations
96+ end
97+ t = (time_ns () - t₀) / 1.0e9
98+ return u, (; solved = n_res <= tol, stats, t)
99+ end
80100
81101xs = LinRange (- 3 , 8 , 1000 )
82102ys = LinRange (- 15 , 10 , 1000 )
@@ -93,21 +113,15 @@ trace_1 = let x₀ = [2.0, 0.5]
93113end
94114lines! (ax, trace_1)
95115
96- trace_2 = let x₀ = [2.5 , 3.0 ]
116+ trace_2 = let x₀ = [2.0 , 0.5 ]
97117 xs = Vector {Tuple{Float64, Float64}} (undef, 0 )
98118 hist (x, res, n_res) = (push! (xs, (x[1 ], x[2 ])); nothing )
99- x, stats = newton_krylov! (F!, x₀, callback = hist)
119+ x, stats = halley_krylov! (F!, x₀, similar (x₀), callback = hist, verbose= 1 , forcing= nothing )
120+ @show stats
100121 xs
101122end
102123lines! (ax, trace_2)
103124
104- trace_3 = let x₀ = [3.0 , 4.0 ]
105- xs = Vector {Tuple{Float64, Float64}} (undef, 0 )
106- hist (x, res, n_res) = (push! (xs, (x[1 ], x[2 ])); nothing )
107- x, stats = newton_krylov! (F!, x₀, callback = hist, forcing = NewtonKrylov. EisenstatWalker (η_max = 0.68949 ), verbose = 1 )
108- @show stats. solved
109- xs
110- end
111- lines! (ax, trace_3)
125+ trace_2
112126
113127fig
0 commit comments