Skip to content

Commit 83c0723

Browse files
committed
auto switch to finitediff for inplace problems
1 parent 7e26d18 commit 83c0723

File tree

3 files changed

+76
-52
lines changed

3 files changed

+76
-52
lines changed

src/linesearch.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,15 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip
9191

9292
g₀ = _mutable_zero(u)
9393

94+
autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
95+
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling back to finite differencing."
96+
AutoFiniteDiff()
97+
else
98+
ls.autodiff
99+
end
100+
94101
function g!(u, fu)
95-
op = VecJac((args...) -> f(args..., p), u)
102+
op = VecJac((args...) -> f(args..., p), u; autodiff)
96103
if iip
97104
mul!(g₀, op, fu)
98105
return g₀
@@ -134,13 +141,16 @@ function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip
134141
end
135142

136143
function perform_linesearch!(cache::LineSearchCache, u, du)
137-
cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α))
144+
cache.ls.method isa Static && return cache.α
138145

139146
ϕ = cache.ϕ(u, du)
140147
= cache.(u, du)
141148
ϕdϕ = cache.ϕdϕ(u, du)
142149

143150
ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
144151

145-
return cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀)
152+
# This case is sometimes possible for large optimization problems
153+
dϕ₀ 0 && return cache.α
154+
155+
return first(cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀))
146156
end

src/raphson.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ function perform_step!(cache::NewtonRaphsonCache{true})
9393
cache.linsolve = linres.cache
9494

9595
# Line Search
96-
α, _ = perform_linesearch!(cache.lscache, u, du)
96+
α = perform_linesearch!(cache.lscache, u, du)
9797
@. u = u - α * du
98+
f(cache.fu1, u, p)
9899

99100
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
100101
cache.stats.nf += 1
@@ -118,7 +119,7 @@ function perform_step!(cache::NewtonRaphsonCache{false})
118119
end
119120

120121
# Line Search
121-
α, _fu = perform_linesearch!(cache.lscache, u, cache.du)
122+
α = perform_linesearch!(cache.lscache, u, cache.du)
122123
cache.u = @. u - α * cache.du # `u` might not support mutation
123124
cache.fu1 = f(cache.u, p)
124125

test/basictests.jl

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ end
5353

5454
@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
5555
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
56+
ad isa AutoZygote && continue
5657
if prec === :Random
5758
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
5859
end
59-
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, linesearch)
60+
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec,
61+
linesearch)
6062
@test SciMLBase.successful_retcode(sol)
6163
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
6264

@@ -67,25 +69,30 @@ end
6769
end
6870

6971
if VERSION v"1.9"
70-
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
71-
@test begin
72-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
73-
res_true = sqrt(p)
74-
all(res.u .≈ res_true)
72+
@testset "[OOP] [Immutable AD]" begin
73+
for p in 1.0:0.1:100.0
74+
@test begin
75+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
76+
res_true = sqrt(p)
77+
all(res.u .≈ res_true)
78+
end
79+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
80+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
7581
end
76-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
77-
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
7882
end
7983
end
8084

81-
@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
82-
@test begin
83-
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
84-
res_true = sqrt(p)
85-
res.u res_true
85+
@testset "[OOP] [Scalar AD]" begin
86+
for p in 1.0:0.1:100.0
87+
@test begin
88+
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
89+
res_true = sqrt(p)
90+
res.u res_true
91+
end
92+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
93+
p)
94+
1 / (2 * sqrt(p))
8695
end
87-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)
88-
1 / (2 * sqrt(p))
8996
end
9097

9198
if VERSION v"1.9"
@@ -162,33 +169,34 @@ end
162169
end
163170

164171
if VERSION v"1.9"
165-
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
166-
p in 1.0:0.1:100.0
172+
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
173+
for p in 1.0:0.1:100.0
174+
@test begin
175+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
176+
radius_update_scheme)
177+
res_true = sqrt(p)
178+
all(res.u .≈ res_true)
179+
end
180+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
181+
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
182+
end
183+
end
184+
end
167185

186+
@testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
187+
for p in 1.0:0.1:100.0
168188
@test begin
169-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
189+
res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p;
170190
radius_update_scheme)
171191
res_true = sqrt(p)
172-
all(res.u . res_true)
192+
res.u res_true
173193
end
174194
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
175-
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
195+
oftype(p, 1.0),
196+
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
176197
end
177198
end
178199

179-
@testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
180-
p in 1.0:0.1:100.0
181-
182-
@test begin
183-
res = benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0), p;
184-
radius_update_scheme)
185-
res_true = sqrt(p)
186-
res.u res_true
187-
end
188-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, oftype(p, 1.0),
189-
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
190-
end
191-
192200
if VERSION v"1.9"
193201
t = (p) -> [sqrt(p[2] / p[1])]
194202
p = [0.9, 50.0]
@@ -316,25 +324,30 @@ end
316324
end
317325

318326
if VERSION v"1.9"
319-
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
320-
@test begin
321-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
322-
res_true = sqrt(p)
323-
all(res.u .≈ res_true)
327+
@testset "[OOP] [Immutable AD]" begin
328+
for p in 1.0:0.1:100.0
329+
@test begin
330+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
331+
res_true = sqrt(p)
332+
all(res.u .≈ res_true)
333+
end
334+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
335+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
324336
end
325-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
326-
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
327337
end
328338
end
329339

330-
@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
331-
@test begin
332-
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
333-
res_true = sqrt(p)
334-
res.u res_true
340+
@testset "[OOP] [Scalar AD]" begin
341+
for p in 1.0:0.1:100.0
342+
@test begin
343+
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
344+
res_true = sqrt(p)
345+
res.u res_true
346+
end
347+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
348+
p)
349+
1 / (2 * sqrt(p))
335350
end
336-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)
337-
1 / (2 * sqrt(p))
338351
end
339352

340353
if VERSION v"1.9"

0 commit comments

Comments
 (0)