Skip to content

Commit e73bb5b

Browse files
author
Oscar Smith
authored
fix type stability of SimpleNonlinearSolve (#536)
* fix type stability of SimpleNonlinearSolve * fix simplehalley
1 parent 72d654d commit e73bb5b

File tree

5 files changed

+43
-47
lines changed

5 files changed

+43
-47
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
3939
} where {iip, T, V, P}
4040

4141
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end
42+
configure_autodiff(prob, alg::AbstractSimpleNonlinearSolveAlgorithm) = alg
4243

4344
const NLBUtils = NonlinearSolveBase.Utils
4445

@@ -59,12 +60,6 @@ function CommonSolve.solve(
5960
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
6061
kwargs...
6162
)
62-
cache = SciMLBase.__init(prob, alg, args...; kwargs...)
63-
prob = cache.prob
64-
if cache.retcode == ReturnCode.InitialFailure
65-
return SciMLBase.build_solution(prob, alg, prob.u0,
66-
NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode)
67-
end
6863
prob = convert(ImmutableNonlinearProblem, prob)
6964
return solve(prob, alg, args...; kwargs...)
7065
end
@@ -73,9 +68,7 @@ function CommonSolve.solve(
7368
prob::DualNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
7469
args...; kwargs...
7570
)
76-
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
77-
@set! alg.autodiff = AutoForwardDiff()
78-
end
71+
alg = configure_autodiff(prob, alg)
7972
prob = convert(ImmutableNonlinearProblem, prob)
8073
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
8174
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
@@ -88,9 +81,7 @@ function CommonSolve.solve(
8881
prob::DualNonlinearLeastSquaresProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
8982
args...; kwargs...
9083
)
91-
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
92-
@set! alg.autodiff = AutoForwardDiff()
93-
end
84+
alg = configure_autodiff(prob, alg)
9485
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
9586
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
9687
return SciMLBase.build_solution(
@@ -103,6 +94,7 @@ function CommonSolve.solve(
10394
alg::AbstractSimpleNonlinearSolveAlgorithm,
10495
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...
10596
)
97+
alg = configure_autodiff(prob, alg)
10698
cache = SciMLBase.__init(prob, alg, args...; kwargs...)
10799
prob = cache.prob
108100
if cache.retcode == ReturnCode.InitialFailure

lib/SimpleNonlinearSolve/src/halley.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,20 @@ A low-overhead implementation of Halley's Method.
2020
autodiff = nothing
2121
end
2222

23+
function configure_autodiff(prob, alg::SimpleHalley)
24+
autodiff = something(alg.autodiff, AutoForwardDiff())
25+
autodiff = SciMLBase.has_jac(prob.f) ? autodiff :
26+
NonlinearSolveBase.select_jacobian_autodiff(prob, autodiff)
27+
@set! alg.autodiff = autodiff
28+
alg
29+
end
30+
2331
function SciMLBase.__solve(
2432
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
2533
abstol = nothing, reltol = nothing, maxiters = 1000,
2634
alias_u0 = false, termination_condition = nothing, kwargs...
2735
)
36+
autodiff = alg.autodiff
2837
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
2938
fx = NLBUtils.evaluate_f(prob, x)
3039
T = promote_type(eltype(fx), eltype(x))
@@ -36,23 +45,21 @@ function SciMLBase.__solve(
3645
prob, abstol, reltol, fx, x, termination_condition, Val(:simple)
3746
)
3847

39-
# The way we write the 2nd order derivatives, we know Enzyme won't work there
40-
autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
41-
@set! alg.autodiff = autodiff
42-
4348
@bb xo = copy(x)
4449

50+
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
51+
NLBUtils.safe_similar(fx) : fx
52+
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
53+
4554
if NLBUtils.can_setindex(x)
46-
A = NLBUtils.safe_similar(x, length(x), length(x))
4755
Aaᵢ = NLBUtils.safe_similar(x, length(x))
4856
cᵢ = NLBUtils.safe_similar(x)
4957
else
50-
A, Aaᵢ, cᵢ = x, x, x
58+
Aaᵢ, cᵢ = x, x, x
5159
end
5260

61+
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
5362
for _ in 1:maxiters
54-
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
55-
5663
NLBUtils.can_setindex(x) || (A = J)
5764

5865
# Factorize Once and Reuse
@@ -67,13 +74,8 @@ function SciMLBase.__solve(
6774
end
6875

6976
aᵢ = J_fact \ NLBUtils.safe_vec(fx)
70-
A_ = NLBUtils.safe_vec(A)
71-
@bb A_ = H × aᵢ
72-
A = NLBUtils.restructure(A, A_)
73-
74-
@bb Aaᵢ = A × aᵢ
75-
@bb A .*= -1
76-
bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ)
77+
hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, x, aᵢ)
78+
bᵢ = J_fact \ NLBUtils.safe_vec(hvvp)
7779

7880
cᵢ_ = NLBUtils.safe_vec(cᵢ)
7981
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
@@ -84,6 +86,9 @@ function SciMLBase.__solve(
8486

8587
@bb @. x += cᵢ
8688
@bb copyto!(xo, x)
89+
90+
fx = NLBUtils.evaluate_f!!(prob, fx, x)
91+
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
8792
end
8893

8994
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)

lib/SimpleNonlinearSolve/src/raphson.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,21 @@ end
2323

2424
const SimpleGaussNewton = SimpleNewtonRaphson
2525

26+
function configure_autodiff(prob, alg::SimpleNewtonRaphson)
27+
autodiff = something(alg.autodiff, AutoForwardDiff())
28+
autodiff = SciMLBase.has_jac(prob.f) ? autodiff :
29+
NonlinearSolveBase.select_jacobian_autodiff(prob, autodiff)
30+
@set! alg.autodiff = autodiff
31+
alg
32+
end
33+
2634
function SciMLBase.__solve(
2735
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
2836
alg::SimpleNewtonRaphson, args...;
2937
abstol = nothing, reltol = nothing, maxiters = 1000,
3038
alias_u0 = false, termination_condition = nothing, kwargs...
3139
)
40+
autodiff = alg.autodiff
3241
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
3342
fx = NLBUtils.evaluate_f(prob, x)
3443

@@ -39,10 +48,6 @@ function SciMLBase.__solve(
3948
prob, abstol, reltol, fx, x, termination_condition, Val(:simple)
4049
)
4150

42-
autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
43-
NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
44-
@set! alg.autodiff = autodiff
45-
4651
@bb xo = similar(x)
4752
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
4853
NLBUtils.safe_similar(fx) : fx

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,26 +158,20 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, ::DINoPreparation)
158158
return J
159159
end
160160

161-
function compute_jacobian_and_hessian(autodiff, prob, _, x::Number)
161+
function compute_hvvp(prob, autodiff, _, x::Number, dir::Number)
162162
H = DI.second_derivative(prob.f, autodiff, x, Constant(prob.p))
163-
fx, J = DI.value_and_derivative(prob.f, autodiff, x, Constant(prob.p))
164-
return fx, J, H
163+
return H*dir
165164
end
166-
function compute_jacobian_and_hessian(autodiff, prob, fx, x)
167-
if SciMLBase.isinplace(prob)
168-
jac_fn = @closure (u, p) -> begin
165+
function compute_hvvp(prob, autodiff, fx, x, dir)
166+
jvp_fn = if SciMLBase.isinplace(prob)
167+
@closure (u, p) -> begin
169168
du = NLBUtils.safe_similar(fx, promote_type(eltype(fx), eltype(u)))
170-
return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
169+
return only(DI.pushforward(prob.f, du, autodiff, u, (dir,), Constant(p)))
171170
end
172-
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
173-
fx = NLBUtils.evaluate_f!!(prob, fx, x)
174-
return fx, J, H
175171
else
176-
jac_fn = @closure (u, p) -> DI.jacobian(prob.f, autodiff, u, Constant(p))
177-
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
178-
fx = NLBUtils.evaluate_f!!(prob, fx, x)
179-
return fx, J, H
172+
@closure (u, p) -> only(DI.pushforward(prob.f, autodiff, u, (dir,), Constant(p)))
180173
end
174+
only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(prob.p)))
181175
end
182176

183177
end

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
]
2929

3030
function run_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F}
31-
return solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9)
31+
return @inferred solve(NonlinearProblem{false}(f, u0, p), solver; abstol = 1e-9)
3232
end
3333
function run_nlsolve_iip(f!::F, u0, p = 2.0; solver) where {F}
34-
return solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9)
34+
return @inferred solve(NonlinearProblem{true}(f!, u0, p), solver; abstol = 1e-9)
3535
end
3636
end
3737

0 commit comments

Comments
 (0)