Skip to content

Commit ca52afb

Browse files
committed
move halley into directory
1 parent e8a3d04 commit ca52afb

File tree

3 files changed

+162
-99
lines changed

3 files changed

+162
-99
lines changed

examples/halley.jl

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -14,90 +14,6 @@ function F(x)
1414
return res
1515
end
1616

17-
import NewtonKrylov: Forcing, EisenstatWalker, inital, forcing, solve!,
18-
JacobianOperator, HessianOperator, Stats, update, GmresSolver
19-
20-
function halley_krylov!(
21-
F!, u::AbstractArray, res::AbstractArray;
22-
tol_rel = 1.0e-6,
23-
tol_abs = 1.0e-12,
24-
max_niter = 50,
25-
forcing::Union{Forcing, Nothing} = EisenstatWalker(),
26-
verbose = 0,
27-
Solver = GmresSolver,
28-
M = nothing,
29-
N = nothing,
30-
krylov_kwargs = (;),
31-
callback = (args...) -> nothing,
32-
)
33-
t₀ = time_ns()
34-
F!(res, u) # res = F(u)
35-
n_res = norm(res)
36-
callback(u, res, n_res)
37-
38-
tol = tol_rel * n_res + tol_abs
39-
40-
if forcing !== nothing
41-
η = inital(forcing)
42-
end
43-
44-
verbose > 0 && @info "Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η
45-
46-
J = JacobianOperator(F!, res, u)
47-
H = HessianOperator(J)
48-
solver = Solver(J, res)
49-
50-
stats = Stats(0, 0)
51-
while n_res > tol && stats.outer_iterations <= max_niter
52-
# Handle kwargs for Preconditoners
53-
kwargs = krylov_kwargs
54-
if N !== nothing
55-
kwargs = (; N = N(J), kwargs...)
56-
end
57-
if M !== nothing
58-
kwargs = (; M = M(J), kwargs...)
59-
end
60-
if forcing !== nothing
61-
# ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
62-
kwargs = (; rtol = η, kwargs...)
63-
end
64-
65-
solve!(solver, J, copy(res); kwargs...) # J \ fx
66-
a = copy(solver.x)
67-
68-
# calculate hvvp (2nd order directional derivative using the JVP)
69-
hvvp = similar(res)
70-
mul!(hvvp, H, a)
71-
72-
solve!(solver, J, hvvp; kwargs...) # J \ hvvp
73-
b = solver.x
74-
75-
# Update u
76-
@. u -= (a * a) / (a - b / 2)
77-
78-
# Update residual and norm
79-
n_res_prior = n_res
80-
81-
F!(res, u) # res = F(u)
82-
n_res = norm(res)
83-
callback(u, res, n_res)
84-
85-
if isinf(n_res) || isnan(n_res)
86-
@error "Inner solver blew up" stats
87-
break
88-
end
89-
90-
if forcing !== nothing
91-
η = forcing(η, tol, n_res, n_res_prior)
92-
end
93-
94-
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
95-
stats = update(stats, solver.stats.niter) # TODO we do two calls to solver iterations
96-
end
97-
t = (time_ns() - t₀) / 1.0e9
98-
return u, (; solved = n_res <= tol, stats, t)
99-
end
100-
10117
xs = LinRange(-3, 8, 1000)
10218
ys = LinRange(-15, 10, 1000)
10319

@@ -116,12 +32,10 @@ lines!(ax, trace_1)
11632
# Record every iterate visited by the Halley-Krylov solver when started from (2.0, 0.5).
trace_2 = let x₀ = [2.0, 0.5]
    iterates = Tuple{Float64, Float64}[]

    # Callback invoked by the solver as `callback(u, res, n_res)`; only `u` is stored.
    function record(x, res, n_res)
        push!(iterates, (x[1], x[2]))
        return nothing
    end

    x, stats = halley_krylov!(
        F!, x₀, similar(x₀);
        callback = record, verbose = 1, forcing = nothing
    )
    @show stats
    iterates
end
12339
lines!(ax, trace_2)
12440

125-
trace_2
126-
12741
fig

src/NewtonKrylov.jl

Lines changed: 107 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module NewtonKrylov
22

33
export newton_krylov, newton_krylov!
4+
export halley_krylov, halley_krylov!
45

56
using Krylov
67
using LinearAlgebra, SparseArrays
@@ -84,6 +85,11 @@ function Base.collect(JOp::JacobianOperator)
8485
return J
8586
end
8687

88+
"""
89+
HessianOperator
90+
91+
Calculates H(F, u) * v * v
92+
"""
8793
struct HessianOperator{F, A}
8894
J::JacobianOperator{F, A}
8995
J_cache::JacobianOperator{F, A}
@@ -95,12 +101,15 @@ Base.eltype(H::HessianOperator) = eltype(H.J)
95101

96102
"""
    mul!(out, H::HessianOperator, v)

Apply the Hessian operator `H` to the direction `v`, storing the result in `out`.

The Jacobian-vector product `mul!(_, H.J, v)` is differentiated in forward mode. `H.J_cache`
acts as the shadow of `H.J`: it is zeroed and its `u` field is set to `v` before the call,
and `out` is passed as the shadow of a scratch output buffer. Returns `nothing`.
"""
function mul!(out, H::HessianOperator, v)
    scratch = similar(H.J.res) # TODO cache in H

    # Reset the shadow operator, then seed the tangent of `u` with the direction `v`.
    Enzyme.make_zero!(H.J_cache)
    H.J_cache.u .= v

    primal_and_shadow_out = DuplicatedNoNeed(scratch, out)
    operator_and_shadow = DuplicatedNoNeed(H.J, H.J_cache)
    autodiff(Forward, mul!, primal_and_shadow_out, operator_and_shadow, Const(v))

    return nothing
end
@@ -247,11 +256,102 @@ function newton_krylov!(
247256
η = forcing(η, tol, n_res, n_res_prior)
248257
end
249258

250-
verbose > 0 && @info "Newton" iter = n_res η=(forcing !== nothing ? η : nothing) stats
259+
verbose > 0 && @info "Newton" iter = n_res η = (forcing !== nothing ? η : nothing) stats
251260
stats = update(stats, solver.stats.niter)
252261
end
253262
t = (time_ns() - t₀) / 1.0e9
254263
return u, (; solved = n_res <= tol, stats, t)
255264
end
256265

266+
"""
    halley_krylov(F, u₀, M = length(u₀); kwargs...)

Out-of-place front end for [`halley_krylov!`](@ref).

`F` is called as `F(u)` and its result is broadcast into the residual buffer; `M` is the
length of that buffer. All keyword arguments are forwarded unchanged. `u₀` itself is handed
to `halley_krylov!`, which updates it in place.
"""
function halley_krylov(F, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    # Adapt the out-of-place `F` to the in-place `F!(res, u)` calling convention.
    function inplace!(res, u)
        res .= F(u)
        return nothing
    end
    return halley_krylov!(inplace!, u₀, M; kwargs...)
end
270+
271+
"""
    halley_krylov!(F!, u₀, M = length(u₀); kwargs...)

Allocate a residual buffer of length `M` with `similar(u₀, M)` and forward to the
`halley_krylov!(F!, u, res; kwargs...)` method. Keyword arguments are passed through as is.
"""
function halley_krylov!(F!, u₀::AbstractArray, M::Int = length(u₀); kwargs...)
    return halley_krylov!(F!, u₀, similar(u₀, M); kwargs...)
end
275+
276+
"""
    halley_krylov!(F!, u, res; kwargs...) -> (u, (; solved, stats, t))

Solve `F!(res, u) = 0` with a Jacobian-free Halley-Krylov iteration, updating `u` and `res`
in place.

Every outer iteration performs two linear solves with the same `JacobianOperator` `J`:
first `J a = res`, then `J b = H a` where `H = HessianOperator(J)`. The iterate is then
updated elementwise as `u -= a^2 / (a - b / 2)`.

# Keywords
- `tol_rel = 1.0e-6`, `tol_abs = 1.0e-12`: the loop stops once
  `norm(res) <= tol_rel * norm(res₀) + tol_abs`, with `res₀` the initial residual.
- `max_niter = 50`: the loop continues only while `stats.outer_iterations <= max_niter`.
- `forcing = EisenstatWalker()`: forcing strategy; its current value `η` is passed to the
  linear solves as `rtol`. With `nothing`, no `rtol` is added.
- `verbose = 0`: any value `> 0` enables `@info` logging.
- `Solver = GmresSolver`: constructed once as `Solver(J, res)`.
- `M`, `N`: if not `nothing`, called as `M(J)` / `N(J)` in every iteration and passed to
  `solve!` under the same keyword.
- `krylov_kwargs = (;)`: additional keywords forwarded to `solve!`.
- `callback`: called as `callback(u, res, n_res)` after each residual evaluation.

# Returns
The updated `u` and a named tuple with `solved` (`norm(res) <= tol` on exit), the iteration
`stats`, and the elapsed wall time `t` in seconds.
"""
function halley_krylov!(
        F!, u::AbstractArray, res::AbstractArray;
        tol_rel = 1.0e-6,
        tol_abs = 1.0e-12,
        max_niter = 50,
        forcing::Union{Forcing, Nothing} = EisenstatWalker(),
        verbose = 0,
        Solver = GmresSolver,
        M = nothing,
        N = nothing,
        krylov_kwargs = (;),
        callback = (args...) -> nothing,
    )
    t₀ = time_ns()
    F!(res, u) # res = F(u)
    n_res = norm(res)
    callback(u, res, n_res)

    tol = tol_rel * n_res + tol_abs

    # Bind η unconditionally: the log statements below reference it even when forcing is
    # disabled, which previously raised an UndefVarError while building the log record.
    η = forcing === nothing ? nothing : inital(forcing)

    verbose > 0 && @info "Jacobian-Free Halley-Krylov" Solver res₀ = n_res tol tol_rel tol_abs η

    J = JacobianOperator(F!, res, u)
    H = HessianOperator(J)
    solver = Solver(J, res)

    # Buffer for the second-order directional derivative, reused across iterations.
    hvvp = similar(res)

    stats = Stats(0, 0)
    while n_res > tol && stats.outer_iterations <= max_niter
        # Handle kwargs for Preconditioners
        kwargs = krylov_kwargs
        if N !== nothing
            kwargs = (; N = N(J), kwargs...)
        end
        if M !== nothing
            kwargs = (; M = M(J), kwargs...)
        end
        if forcing !== nothing
            # ‖F′(u)d + F(u)‖ <= η * ‖F(u)‖ Inexact Newton termination
            kwargs = (; rtol = η, kwargs...)
        end

        solve!(solver, J, copy(res); kwargs...) # J \ fx
        # Copy: the second solve below overwrites `solver.x` and `solver.stats`.
        a = copy(solver.x)
        inner_iterations = solver.stats.niter

        # calculate hvvp (2nd order directional derivative using the JVP)
        mul!(hvvp, H, a)

        solve!(solver, J, hvvp; kwargs...) # J \ hvvp
        b = solver.x
        inner_iterations += solver.stats.niter

        # Update u
        @. u -= (a * a) / (a - b / 2)

        # Update residual and norm
        n_res_prior = n_res

        F!(res, u) # res = F(u)
        n_res = norm(res)
        callback(u, res, n_res)

        if isinf(n_res) || isnan(n_res)
            @error "Inner solver blew up" stats
            break
        end

        if forcing !== nothing
            η = forcing(η, tol, n_res, n_res_prior)
        end

        verbose > 0 && @info "Halley" iter = n_res η stats
        # Account for the Krylov iterations of both linear solves of this outer step.
        stats = update(stats, inner_iterations)
    end
    t = (time_ns() - t₀) / 1.0e9
    return u, (; solved = n_res <= tol, stats, t)
end
356+
257357
end # module NewtonKrylov

test/runtests.jl

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,65 @@ end
2525
import NewtonKrylov: JacobianOperator, HessianOperator
using Enzyme, LinearAlgebra
2727

28+
# Forward-mode derivative of `F` at `x` with `a` as the tangent of `x`.
# Returns the first element of what `autodiff` produces.
function df(x, a)
    derivative = autodiff(Forward, F, DuplicatedNoNeed(x, a))
    return first(derivative)
end
31+
32+
# In-place counterpart of `df`: differentiates `F!` in forward mode with `a` as the tangent
# of `x`. `out` is passed as the shadow of a scratch residual buffer. Returns `nothing`.
function df!(out, x, a)
    scratch = similar(out)
    autodiff(
        Forward, F!,
        DuplicatedNoNeed(scratch, out),
        DuplicatedNoNeed(x, a),
    )
    return nothing
end
37+
2838
@testset "Jacobian" begin
    x = [3.0, 5.0]
    v = rand(2)

    # Dense reference Jacobian from Enzyme vs. the matrix collected from the operator.
    J_Enz = jacobian(Forward, F, x) |> only
    J = JacobianOperator(F!, zeros(2), x)
    J_NK = collect(J)

    @test J_NK == J_Enz

    # Jacobian-vector product computed through the operator ...
    jvp = similar(v)
    mul!(jvp, J, v)

    # ... compared against the out-of-place and in-place forward-mode helpers.
    jvp2 = df(x, v)
    @test jvp == jvp2

    jvp3 = similar(v)
    df!(jvp3, x, v)
    @test jvp == jvp3

    # Compared against the dense product with a floating-point tolerance.
    @test jvp ≈ J_Enz * v
end
60+
61+
# Differentiate F with respect to x twice.
62+
# Applies forward-mode `autodiff` to `df` itself, using `a` both as the tangent of `x` and
# as the (constant) direction argument of `df`. Returns the first element of the result.
function ddf(x, a)
    second_derivative = autodiff(Forward, df, DuplicatedNoNeed(x, a), Const(a))
    return first(second_derivative)
end
65+
66+
# In-place counterpart of `ddf`: applies forward-mode `autodiff` to `df!`. `out` is passed
# as the shadow of a scratch output buffer and `a` is both the tangent of `x` and the
# constant direction argument. Returns `nothing`.
function ddf!(out, x, a)
    scratch = similar(out)
    autodiff(
        Forward, df!,
        DuplicatedNoNeed(scratch, out),
        DuplicatedNoNeed(x, a),
        Const(a),
    )
    return nothing
end
71+
72+
@testset "2nd order directional derivative" begin
    x = [3.0, 5.0]
    v = rand(2)

    # Reference values from nesting forward-mode autodiff (in-place and out-of-place).
    hvvp = similar(x)
    ddf!(hvvp, x, v)

    hvvp2 = ddf(x, v)
    @test hvvp == hvvp2

    # The same quantity computed through the package's operator types.
    J = JacobianOperator(F!, zeros(2), x)
    H = HessianOperator(J)

    hvvp3 = similar(x)
    mul!(hvvp3, H, v)

    @test hvvp == hvvp3
end

0 commit comments

Comments
 (0)