Skip to content

Commit 2e3c3f1

Browse files
committed
move halley into directory
1 parent c546355 commit 2e3c3f1

File tree

3 files changed

+162
-99
lines changed

3 files changed

+162
-99
lines changed

examples/halley.jl

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -14,90 +14,6 @@ function F(x)
1414
return res
1515
end
1616

17-
import NewtonKrylov: Forcing, EisenstatWalker, inital, forcing, solve!,
18-
JacobianOperator, HessianOperator, Stats, update, GmresSolver
19-
20-
function halley_krylov!(
21-
F!, u::AbstractArray, res::AbstractArray;
22-
tol_rel = 1.0e-6,
23-
tol_abs = 1.0e-12,
24-
max_niter = 50,
25-
forcing::Union{Forcing, Nothing} = EisenstatWalker(),
26-
verbose = 0,
27-
Solver = GmresSolver,
28-
M = nothing,
29-
N = nothing,
30-
krylov_kwargs = (;),
31-
callback = (args...) -> nothing,
32-
)
33-
t₀ = time_ns()
34-
F!(res, u) # res = F(u)
35-
n_res = norm(res)
36-
callback(u, res, n_res)
37-
38-
tol = tol_rel * n_res + tol_abs
39-
40-
if forcing !== nothing
41-
η = inital(forcing)
42-
end
43-
44-
verbose > 0 && @info "Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η
45-
46-
J = JacobianOperator(F!, res, u)
47-
H = HessianOperator(J)
48-
solver = Solver(J, res)
49-
50-
stats = Stats(0, 0)
51-
while n_res > tol && stats.outer_iterations <= max_niter
52-
# Handle kwargs for Preconditoners
53-
kwargs = krylov_kwargs
54-
if N !== nothing
55-
kwargs = (; N = N(J), kwargs...)
56-
end
57-
if M !== nothing
58-
kwargs = (; M = M(J), kwargs...)
59-
end
60-
if forcing !== nothing
61-
# ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
62-
kwargs = (; rtol = η, kwargs...)
63-
end
64-
65-
solve!(solver, J, copy(res); kwargs...) # J \ fx
66-
a = copy(solver.x)
67-
68-
# calculate hvvp (2nd order directional derivative using the JVP)
69-
hvvp = similar(res)
70-
mul!(hvvp, H, a)
71-
72-
solve!(solver, J, hvvp; kwargs...) # J \ hvvp
73-
b = solver.x
74-
75-
# Update u
76-
@. u -= (a * a) / (a - b / 2)
77-
78-
# Update residual and norm
79-
n_res_prior = n_res
80-
81-
F!(res, u) # res = F(u)
82-
n_res = norm(res)
83-
callback(u, res, n_res)
84-
85-
if isinf(n_res) || isnan(n_res)
86-
@error "Inner solver blew up" stats
87-
break
88-
end
89-
90-
if forcing !== nothing
91-
η = forcing(η, tol, n_res, n_res_prior)
92-
end
93-
94-
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
95-
stats = update(stats, solver.stats.niter) # TODO we do two calls to solver iterations
96-
end
97-
t = (time_ns() - t₀) / 1.0e9
98-
return u, (; solved = n_res <= tol, stats, t)
99-
end
100-
10117
xs = LinRange(-3, 8, 1000)
10218
ys = LinRange(-15, 10, 1000)
10319

@@ -116,12 +32,10 @@ lines!(ax, trace_1)
11632
trace_2 = let x₀ = [2.0, 0.5]
11733
xs = Vector{Tuple{Float64, Float64}}(undef, 0)
11834
hist(x, res, n_res) = (push!(xs, (x[1], x[2])); nothing)
119-
x, stats = halley_krylov!(F!, x₀, similar(x₀), callback = hist, verbose=1, forcing=nothing)
35+
x, stats = halley_krylov!(F!, x₀, similar(x₀), callback = hist, verbose = 1, forcing = nothing)
12036
@show stats
12137
xs
12238
end
12339
lines!(ax, trace_2)
12440

125-
trace_2
126-
12741
fig

src/NewtonKrylov.jl

Lines changed: 107 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NewtonKrylov
22

33
export newton_krylov, newton_krylov!
4+
export halley_krylov, halley_krylov!
45

56
using Krylov
67
using LinearAlgebra, SparseArrays
@@ -89,6 +90,11 @@ function Base.collect(JOp::JacobianOperator)
8990
return J
9091
end
9192

93+
"""
94+
HessianOperator
95+
96+
Calculates H(F, u) * v * v
97+
"""
9298
struct HessianOperator{F, A}
9399
J::JacobianOperator{F, A}
94100
J_cache::JacobianOperator{F, A}
@@ -100,12 +106,15 @@ Base.eltype(H::HessianOperator) = eltype(H.J)
100106

101107
function mul!(out, H::HessianOperator, v)
102108
_out = similar(H.J.res) # TODO cache in H
103-
du = Enzyme.make_zero(H.J.u) # TODO cache in H
104-
105-
autodiff(Forward, mul!,
106-
DuplicatedNoNeed(_out, out),
107-
DuplicatedNoNeed(H.J, H.J_cache),
108-
DuplicatedNoNeed(du, v))
109+
Enzyme.make_zero!(H.J_cache)
110+
H.J_cache.u .= v
111+
autodiff(
112+
Forward,
113+
mul!,
114+
DuplicatedNoNeed(_out, out),
115+
DuplicatedNoNeed(H.J, H.J_cache),
116+
Const(v)
117+
)
109118

110119
return nothing
111120
end
@@ -299,11 +308,102 @@ function newton_krylov!(
299308
η = forcing(η, tol, n_res, n_res_prior)
300309
end
301310

302-
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
311+
verbose > 0 && @info "Newton" iter = n_res η = (forcing !== nothing ? η : nothing) stats
303312
stats = update(stats, solver.stats.niter)
304313
end
305314
t = (time_ns() - t₀) / 1.0e9
306315
return u, (; solved = n_res <= tol, stats, t)
307316
end
308317

318+
"""
    halley_krylov(F, u₀::AbstractArray, M::Int = length(u₀); kwargs...)

Out-of-place entry point for the Halley-Krylov solver.

`F` is called as `F(u)` and its result is broadcast into the residual buffer,
which adapts it to the in-place `F!(res, u)` convention. `u₀`, `M` and all
keyword arguments are forwarded unchanged to `halley_krylov!`, whose return
value is returned as-is.
"""
function halley_krylov(F, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    # In-place adapter around the user's out-of-place function.
    function F_inplace!(res, u)
        res .= F(u)
        return nothing
    end
    return halley_krylov!(F_inplace!, u₀, M; kwargs...)
end
322+
323+
"""
    halley_krylov!(F!, u₀::AbstractArray, M::Int = length(u₀); kwargs...)

Convenience method that allocates the residual buffer.

A buffer of length `M` is created with `similar(u₀, M)` and the call is
forwarded to the `halley_krylov!(F!, u, res; kwargs...)` method together with
all keyword arguments.
"""
function halley_krylov!(F!, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    return halley_krylov!(F!, u₀, similar(u₀, M); kwargs...)
end
327+
328+
"""
    halley_krylov!(F!, u, res; kwargs...) -> (u, (; solved, stats, t))

Jacobian-free Halley-Krylov iteration for `F!(res, u)`, updating `u` and `res`
in place.

Each outer iteration performs two linear solves with the Jacobian operator
`J`: `a = J \\ F(u)` and `b = J \\ (H * a)`, where `H * a` is the product
computed by `HessianOperator`. The iterate is then updated element-wise with
`u -= a^2 / (a - b / 2)`.

# Keywords
- `tol_rel`, `tol_abs`: the iteration stops once
  `‖F(u)‖ <= tol_rel * ‖F(u₀)‖ + tol_abs`.
- `max_niter`: the loop runs while `stats.outer_iterations <= max_niter`.
- `forcing`: if not `nothing`, `inital(forcing)` provides the first `η`, which
  is passed as `rtol` to both linear solves and updated after every iteration
  via `forcing(η, tol, n_res, n_res_prior)`.
- `verbose`: values `> 0` emit `@info` records.
- `Solver`: linear solver constructor, called as `Solver(J, res)`.
- `M`, `N`: optional callables; `M(J)` / `N(J)` are evaluated every iteration
  and passed to `solve!` as the `M` / `N` keyword arguments.
- `krylov_kwargs`: extra keyword arguments forwarded to `solve!`.
- `callback`: called as `callback(u, res, n_res)` after every residual
  evaluation, including the initial one.

# Returns
`u` and a named tuple with `solved = n_res <= tol`, the iteration `stats`, and
the elapsed wall time `t` in seconds.
"""
function halley_krylov!(
        F!, u::AbstractArray, res::AbstractArray;
        tol_rel = 1.0e-6,
        tol_abs = 1.0e-12,
        max_niter = 50,
        forcing::Union{Forcing, Nothing} = EisenstatWalker(),
        verbose = 0,
        Solver = GmresSolver,
        M = nothing,
        N = nothing,
        krylov_kwargs = (;),
        callback = (args...) -> nothing,
    )
    t₀ = time_ns()
    F!(res, u) # res = F(u)
    n_res = norm(res)
    callback(u, res, n_res)

    tol = tol_rel * n_res + tol_abs

    # Bind η unconditionally: the log statements below read it even when no
    # forcing strategy is used, which previously referenced an undefined variable.
    η = forcing === nothing ? nothing : inital(forcing)

    verbose > 0 && @info "Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η

    J = JacobianOperator(F!, res, u)
    H = HessianOperator(J)
    solver = Solver(J, res)

    stats = Stats(0, 0)
    while n_res > tol && stats.outer_iterations <= max_niter
        # Handle kwargs for preconditioners
        kwargs = krylov_kwargs
        if N !== nothing
            kwargs = (; N = N(J), kwargs...)
        end
        if M !== nothing
            kwargs = (; M = M(J), kwargs...)
        end
        if forcing !== nothing
            # ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
            kwargs = (; rtol = η, kwargs...)
        end

        solve!(solver, J, copy(res); kwargs...) # J \ fx
        # Copy the solution: the same solver object is reused for the second
        # solve below, whose result is read from `solver.x` again.
        a = copy(solver.x)
        inner_iterations = solver.stats.niter

        # calculate hvvp (2nd order directional derivative using the JVP)
        hvvp = similar(res)
        mul!(hvvp, H, a)

        solve!(solver, J, hvvp; kwargs...) # J \ hvvp
        b = solver.x
        # Account for both linear solves of this outer iteration.
        # NOTE(review): assumes `solver.stats.niter` reports only the most
        # recent solve and that `update` adds its argument to the running
        # inner-iteration count -- confirm against `Stats`/`update`.
        inner_iterations += solver.stats.niter

        # Update u
        @. u -= (a * a) / (a - b / 2)

        # Update residual and norm
        n_res_prior = n_res

        F!(res, u) # res = F(u)
        n_res = norm(res)
        callback(u, res, n_res)

        if isinf(n_res) || isnan(n_res)
            @error "Inner solver blew up" stats
            break
        end

        if forcing !== nothing
            η = forcing(η, tol, n_res, n_res_prior)
        end

        # NOTE(review): the key `iter` carries the residual norm, mirroring the
        # log record in `newton_krylov!`; kept as-is for consistency.
        verbose > 0 && @info "Halley" iter = n_res η stats
        stats = update(stats, inner_iterations)
    end
    t = (time_ns() - t₀) / 1.0e9
    return u, (; solved = n_res <= tol, stats, t)
end
408+
309409
end # module NewtonKrylov

test/runtests.jl

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,65 @@ end
2525
import NewtonKrylov: JacobianOperator
2626
using Enzyme, LinearAlgebra
2727

28+
# First-order directional derivative of the out-of-place `F` at `x` along `a`,
# computed with forward-mode `autodiff`; returns the first entry of its result.
function df(x, a)
    derivs = autodiff(Forward, F, DuplicatedNoNeed(x, a))
    return first(derivs)
end
31+
32+
# In-place counterpart of `df`: forward-mode `autodiff` of `F!` with `out` as the
# shadow of the residual and `a` as the shadow of `x`, so the directional
# derivative ends up in `out`. The primal residual goes to a scratch buffer.
function df!(out, x, a)
    primal = similar(out)
    autodiff(Forward, F!, DuplicatedNoNeed(primal, out), DuplicatedNoNeed(x, a))
    return nothing
end
37+
2838
@testset "Jacobian" begin
29-
J_Enz = jacobian(Forward, F, [3.0, 5.0]) |> only
30-
J = JacobianOperator(F!, zeros(2), [3.0, 5.0])
39+
x = [3.0, 5.0]
40+
v = rand(2)
41+
42+
J_Enz = jacobian(Forward, F, x) |> only
43+
J = JacobianOperator(F!, zeros(2), x)
3144
J_NK = collect(J)
3245

3346
@test J_NK == J_Enz
3447

48+
jvp = similar(v)
49+
mul!(jvp, J, v)
50+
51+
jvp2 = df(x, v)
52+
@test jvp == jvp2
53+
54+
jvp3 = similar(v)
55+
df!(jvp3, x, v)
56+
@test jvp == jvp3
57+
58+
@test jvp ≈ J_Enz * v
59+
end
60+
61+
# Differentiate F with respect to x twice.
62+
# Second-order directional derivative: forward-mode `autodiff` applied to `df`,
# seeding `x` with the direction `a` while the direction argument of `df` is
# held constant. Returns the first entry of the `autodiff` result.
function ddf(x, a)
    derivs = autodiff(Forward, df, DuplicatedNoNeed(x, a), Const(a))
    return first(derivs)
end
65+
66+
# In-place counterpart of `ddf`: forward-mode `autodiff` applied to `df!`, with
# `out` as the shadow of df!'s output buffer, `a` as the shadow of `x`, and the
# direction argument held constant. The primal output goes to a scratch buffer.
function ddf!(out, x, a)
    scratch = similar(out)
    autodiff(
        Forward, df!,
        DuplicatedNoNeed(scratch, out),
        DuplicatedNoNeed(x, a),
        Const(a),
    )
    return nothing
end
71+
72+
@testset "2nd order directional derivative" begin
73+
x = [3.0, 5.0]
3574
v = rand(2)
36-
out = similar(v)
37-
mul!(out, J, v)
3875

39-
@test out ≈ J_Enz * v
76+
hvvp = similar(x)
77+
ddf!(hvvp, x, v)
78+
79+
hvvp2 = ddf(x, v)
80+
@test hvvp == hvvp2
81+
82+
J = JacobianOperator(F!, zeros(2), x)
83+
H = HessianOperator(J)
84+
85+
hvvp3 = similar(x)
86+
mul!(hvvp3, H, v)
87+
88+
@test hvvp == hvvp3
4089
end

0 commit comments

Comments
 (0)