Skip to content

Commit 4542543

Browse files
committed
Add scalar nonlinear solve AD
1 parent 08c9701 commit 4542543

File tree

4 files changed

+44
-36
lines changed

4 files changed

+44
-36
lines changed

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module NonlinearSolve
33
using Reexport
44
using UnPack: @unpack
55
using FiniteDiff, ForwardDiff
6+
using ForwardDiff: Dual
67
using Setfield
78
using StaticArrays
89
using RecursiveArrayTools

src/scalar.jl

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,49 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
1919
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

22-
function solve(prob::NonlinearProblem{<:Number, iip, <:ForwardDiff.Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {uType, iip, T, V, P}
22+
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2323
f = prob.f
24-
p = ForwardDiff.value(prob.p)
25-
u0 = ForwardDiff.value(prob.u0)
24+
p = value(prob.p)
25+
u0 = value(prob.u0)
26+
2627
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
2728
sol = solve(newprob, alg, args...; kwargs...)
28-
f_p = ForwardDiff.derivative(Base.Fix1(f, sol.u), p)
29-
f_x = ForwardDiff.derivative(Base.Fix2(f, p), sol.u)
30-
partials = (-f_p / f_x) * ForwardDiff.partials(prob.p)
31-
return NewtonSolution(ForwardDiff.Dual{T,V,P}(sol.u, partials), sol.retcode)
29+
30+
uu = getsolution(sol)
31+
if p isa Number
32+
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
33+
else
34+
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
35+
end
36+
37+
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
38+
pp = prob.p
39+
sumfun = let f_x′ = -f_x
40+
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
41+
end
42+
partials = sum(sumfun, zip(f_p, pp))
43+
return sol, partials
3244
end
3345

34-
function solve(prob::NonlinearProblem{uType, iip, <:ForwardDiff.Dual{T,V,P}}, alg::Bisection, args...; kwargs...) where {uType, iip, T, V, P}
35-
prob_nodual = NonlinearProblem(prob.f, prob.u0, ForwardDiff.value(prob.p); prob.kwargs...)
36-
sol = solve(prob_nodual, alg, args...; kwargs...)
37-
# f, x and p always satisfy
38-
# f(x, p) = 0
39-
# dx * f_x(x, p) + dp * f_p(x, p) = 0
40-
# dx / dp = - f_p(x, p) / f_x(x, p)
41-
f_p = (p) -> prob.f(sol.left, p)
42-
f_x = (x) -> prob.f(x, ForwardDiff.value(prob.p))
43-
d_p = ForwardDiff.derivative(f_p, ForwardDiff.value(prob.p))
44-
d_x = ForwardDiff.derivative(f_x, sol.left)
45-
partials = - d_p / d_x * ForwardDiff.partials(prob.p)
46-
return BracketingSolution(ForwardDiff.Dual{T,V,P}(sol.left, partials), ForwardDiff.Dual{T,V,P}(sol.right, partials), sol.retcode)
46+
function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
47+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
48+
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
49+
end
50+
function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
51+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
52+
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
4753
end
4854

49-
# still WIP
50-
function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:ForwardDiff.Dual{T,V,P}, N}}, alg::Bisection, args...; kwargs...) where {uType, iip, T, V, P, N}
51-
p_nodual = ForwardDiff.value.(prob.p)
52-
prob_nodual = NonlinearProblem(prob.f, prob.u0, p_nodual; prob.kwargs...)
53-
sol = solve(prob_nodual, alg, args...; kwargs...)
54-
# f, x and p always satisfy
55-
# f(x, p) = 0
56-
# dx * f_x(x, p) + dp * f_p(x, p) = 0
57-
# dx / dp = - f_p(x, p) / f_x(x, p)
58-
f_p = (p) -> [ prob.f(sol.left, p) ]
59-
f_x = (x) -> prob.f(x, p_nodual)
60-
d_p = ForwardDiff.jacobian(f_p, p_nodual)
61-
d_x = ForwardDiff.derivative(f_x, sol.left)
62-
@. d_p = - d_p / d_x
63-
@show ForwardDiff.partials.(prob.p)
64-
return ForwardDiff.Dual{T,V,P}(sol.left, d_p * ForwardDiff.partials.(prob.p))
55+
# avoid ambiguities
56+
for Alg in [Bisection, Falsi]
57+
@eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
58+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
59+
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
60+
end
61+
@eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
62+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
63+
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
64+
end
6565
end
6666

6767
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)

src/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@ function sync_residuals!(solver::BracketingImmutableSolver)
7878
@set! solver.fr = solver.f(solver.right, solver.p)
7979
solver
8080
end
81+
82+
getsolution(sol::NewtonSolution) = sol.u
83+
getsolution(sol::BracketingSolution) = sol.left

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,7 @@ function value_derivative(f::F, x::R) where {F,R}
234234
out = f(ForwardDiff.Dual{T}(x, one(x)))
235235
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
236236
end
237+
238+
value(x) = x
239+
value(x::Dual) = ForwardDiff.value(x)
240+
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

0 commit comments

Comments
 (0)