Skip to content

Commit c546355

Browse files
committed
mirror newton_krylov
1 parent ca593cc commit c546355

File tree

2 files changed

+65
-51
lines changed

2 files changed

+65
-51
lines changed

examples/halley.jl

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
using NewtonKrylov, LinearAlgebra
44
using CairoMakie
5-
using Krylov, Enzyme
65

76
function F!(res, x)
87
res[1] = x[1]^2 + x[2]^2 - 2
@@ -15,68 +14,89 @@ function F(x)
1514
return res
1615
end
1716

18-
function halley(F!, u, res;
17+
import NewtonKrylov: Forcing, EisenstatWalker, inital, forcing, solve!,
18+
JacobianOperator, HessianOperator, Stats, update, GmresSolver
19+
20+
function halley_krylov!(
21+
F!, u::AbstractArray, res::AbstractArray;
1922
tol_rel = 1.0e-6,
2023
tol_abs = 1.0e-12,
2124
max_niter = 50,
25+
forcing::Union{Forcing, Nothing} = EisenstatWalker(),
26+
verbose = 0,
2227
Solver = GmresSolver,
28+
M = nothing,
29+
N = nothing,
30+
krylov_kwargs = (;),
31+
callback = (args...) -> nothing,
2332
)
24-
33+
t₀ = time_ns()
2534
F!(res, u) # res = F(u)
2635
n_res = norm(res)
36+
callback(u, res, n_res)
2737

2838
tol = tol_rel * n_res + tol_abs
2939

30-
J = NewtonKrylov.JacobianOperator(F!, res, u)
31-
H = NewtonKrylov.HessianOperator(J)
40+
if forcing !== nothing
41+
η = inital(forcing)
42+
end
43+
44+
verbose > 0 && @info "Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η
45+
46+
J = JacobianOperator(F!, res, u)
47+
H = HessianOperator(J)
3248
solver = Solver(J, res)
3349

34-
for i in :max_niter
35-
if n_res <= tol
36-
break
50+
stats = Stats(0, 0)
51+
while n_res > tol && stats.outer_iterations <= max_niter
52+
# Handle kwargs for Preconditoners
53+
kwargs = krylov_kwargs
54+
if N !== nothing
55+
kwargs = (; N = N(J), kwargs...)
3756
end
38-
solve!(solver, J, copy(res)) # J \ fx
57+
if M !== nothing
58+
kwargs = (; M = M(J), kwargs...)
59+
end
60+
if forcing !== nothing
61+
# ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
62+
kwargs = (; rtol = η, kwargs...)
63+
end
64+
65+
solve!(solver, J, copy(res); kwargs...) # J \ fx
3966
a = copy(solver.x)
4067

4168
# calculate hvvp (2nd order directional derivative using the JVP)
4269
hvvp = similar(res)
4370
mul!(hvvp, H, a)
4471

45-
solve!(solver, J, hvvp) # J \ hvvp
72+
solve!(solver, J, hvvp; kwargs...) # J \ hvvp
4673
b = solver.x
4774

48-
# update
75+
# Update u
4976
@. u -= (a * a) / (a - b / 2)
5077

51-
end
52-
end
53-
54-
# u = [2.0, 0.5]
55-
# res = zeros(2)
56-
# J = NewtonKrylov.JacobianOperator(F!,u,res)
57-
# F!(res, u)
58-
# a, stats = gmres(J, copy(res))
78+
# Update residual and norm
79+
n_res_prior = n_res
5980

60-
# J_cache = Enzyme.make_zero(J)
61-
# out = similar(J.res)
62-
# hvvp = Enzyme.make_zero(out)
63-
# du = Enzyme.make_zero(J.u)
64-
# autodiff(Forward, LinearAlgebra.mul!,
65-
# DuplicatedNoNeed(out, hvvp),
66-
# DuplicatedNoNeed(J, J_cache),
67-
# DuplicatedNoNeed(du, a))
81+
F!(res, u) # res = F(u)
82+
n_res = norm(res)
83+
callback(u, res, n_res)
6884

69-
# hvvp
70-
71-
# b, stats = gmres(J, hvvp)
72-
# @. u -= (a * a) / (a - b / 2)
73-
74-
# a
85+
if isinf(n_res) || isnan(n_res)
86+
@error "Inner solver blew up" stats
87+
break
88+
end
7589

90+
if forcing !== nothing
91+
η = forcing(η, tol, n_res, n_res_prior)
92+
end
7693

77-
dg_ad(x, dx) = autodiff(Forward, flux, DuplicatedNoNeed(x, dx))[1]
78-
ddg_ad(x, dx, ddx) = autodiff(Forward, dg_ad, DuplicatedNoNeed(x, dx),
79-
DuplicatedNoNeed(dx, ddx))[1]
94+
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
95+
stats = update(stats, solver.stats.niter) # TODO we do two calls to solver iterations
96+
end
97+
t = (time_ns() - t₀) / 1.0e9
98+
return u, (; solved = n_res <= tol, stats, t)
99+
end
80100

81101
xs = LinRange(-3, 8, 1000)
82102
ys = LinRange(-15, 10, 1000)
@@ -93,21 +113,15 @@ trace_1 = let x₀ = [2.0, 0.5]
93113
end
94114
lines!(ax, trace_1)
95115

96-
trace_2 = let x₀ = [2.5, 3.0]
116+
trace_2 = let x₀ = [2.0, 0.5]
97117
xs = Vector{Tuple{Float64, Float64}}(undef, 0)
98118
hist(x, res, n_res) = (push!(xs, (x[1], x[2])); nothing)
99-
x, stats = newton_krylov!(F!, x₀, callback = hist)
119+
x, stats = halley_krylov!(F!, x₀, similar(x₀), callback = hist, verbose=1, forcing=nothing)
120+
@show stats
100121
xs
101122
end
102123
lines!(ax, trace_2)
103124

104-
trace_3 = let x₀ = [3.0, 4.0]
105-
xs = Vector{Tuple{Float64, Float64}}(undef, 0)
106-
hist(x, res, n_res) = (push!(xs, (x[1], x[2])); nothing)
107-
x, stats = newton_krylov!(F!, x₀, callback = hist, forcing = NewtonKrylov.EisenstatWalker(η_max = 0.68949), verbose = 1)
108-
@show stats.solved
109-
xs
110-
end
111-
lines!(ax, trace_3)
125+
trace_2
112126

113127
fig

src/NewtonKrylov.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,15 @@ function newton_krylov!(
273273
end
274274

275275
# Solve: Jx = -res
276-
# res is modifyed by J, so we create a copy `-res`
277-
# TODO: provide a temporary storage for `-res`
278-
solve!(solver, J, -res; kwargs...)
276+
# res is modifyed by J, so we create a copy `res`
277+
# TODO: provide a temporary storage for `res`
278+
solve!(solver, J, copy(res); kwargs...)
279279

280280
d = solver.x # Newton direction
281281
s = 1 # Newton step TODO: LineSearch
282282

283283
# Update u
284-
u .+= s .* d
284+
@. u -= s * d
285285

286286
# Update residual and norm
287287
n_res_prior = n_res
@@ -299,7 +299,7 @@ function newton_krylov!(
299299
η = forcing(η, tol, n_res, n_res_prior)
300300
end
301301

302-
verbose > 0 && @info "Newton" iter = n_res η stats
302+
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
303303
stats = update(stats, solver.stats.niter)
304304
end
305305
t = (time_ns() - t₀) / 1.0e9

0 commit comments

Comments
 (0)