Skip to content

Commit 09bb525

Browse files
authored
Merge pull request #11 from JuliaComputing/myb/newtonad
Scalar nonlinear solve AD
2 parents f8af971 + 3e21c45 commit 09bb525

File tree

6 files changed

+80
-3
lines changed

6 files changed

+80
-3
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Reexport = "0.2"
2020
Setfield = "0.7"
2121
StaticArrays = "0.11, 0.12"
2222
UnPack = "0.1, 1.0"
23+
julia = "1"
2324

2425
[extras]
2526
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module NonlinearSolve
33
using Reexport
44
using UnPack: @unpack
55
using FiniteDiff, ForwardDiff
6+
using ForwardDiff: Dual
67
using Setfield
78
using StaticArrays
89
using RecursiveArrayTools

src/scalar.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,51 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
1919
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

22+
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
23+
f = prob.f
24+
p = value(prob.p)
25+
u0 = value(prob.u0)
26+
27+
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
28+
sol = solve(newprob, alg, args...; kwargs...)
29+
30+
uu = getsolution(sol)
31+
if p isa Number
32+
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
33+
else
34+
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
35+
end
36+
37+
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
38+
pp = prob.p
39+
sumfun = let f_x′ = -f_x
40+
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
41+
end
42+
partials = sum(sumfun, zip(f_p, pp))
43+
return sol, partials
44+
end
45+
46+
function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
47+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
48+
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
49+
end
50+
function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
51+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
52+
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
53+
end
54+
55+
# avoid ambiguities
56+
for Alg in [Bisection, Falsi]
57+
@eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
58+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
59+
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
60+
end
61+
@eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
62+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
63+
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
64+
end
65+
end
66+
2267
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
2368
f = Base.Fix2(prob.f, prob.p)
2469
left, right = prob.u0

src/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@ function sync_residuals!(solver::BracketingImmutableSolver)
7878
@set! solver.fr = solver.f(solver.right, solver.p)
7979
solver
8080
end
81+
82+
getsolution(sol::NewtonSolution) = sol.u
83+
getsolution(sol::BracketingSolution) = sol.left

src/utils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ Move `x` one floating point towards x0.
216216
function prevfloat_tdir(x::T, x0::T, x1::T)::T where {T}
217217
x1 > x0 ? prevfloat(x) : nextfloat(x)
218218
end
219-
219+
220220
function nextfloat_tdir(x::T, x0::T, x1::T)::T where {T}
221221
x1 > x0 ? nextfloat(x) : prevfloat(x)
222222
end
@@ -234,3 +234,7 @@ function value_derivative(f::F, x::R) where {F,R}
234234
out = f(ForwardDiff.Dual{T}(x, one(x)))
235235
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
236236
end
237+
238+
value(x) = x
239+
value(x::Dual) = ForwardDiff.value(x)
240+
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

test/runtests.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,41 @@ end
5757
f, u0 = (u, p) -> u * u - p, 1.0
5858

5959
g = function (p)
60-
probN = NonlinearProblem{false}(f, u0, p)
60+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
6161
sol = solve(probN, NewtonRaphson())
6262
return sol.u
6363
end
6464

65-
@test_broken ForwardDiff.derivative(g, 1.0) 0.5
65+
@test ForwardDiff.derivative(g, 1.0) 0.5
6666

6767
for p in 1.1:0.1:100.0
6868
@test g(p) sqrt(p)
6969
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
7070
end
7171

72+
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
73+
t = (p) -> [sqrt(p[2] / p[1])]
74+
p = [0.9, 50.0]
75+
for alg in [Bisection(), Falsi()]
76+
global g, p
77+
g = function (p)
78+
probN = NonlinearProblem{false}(f, u0, p)
79+
sol = solve(probN, Bisection())
80+
return [sol.left]
81+
end
82+
83+
@test g(p) [sqrt(p[2] / p[1])]
84+
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
85+
end
86+
87+
gnewton = function (p)
88+
probN = NonlinearProblem{false}(f, 0.5, p)
89+
sol = solve(probN, NewtonRaphson())
90+
return [sol.u]
91+
end
92+
@test gnewton(p) [sqrt(p[2] / p[1])]
93+
@test ForwardDiff.jacobian(gnewton, p) ForwardDiff.jacobian(t, p)
94+
7295
# Error Checks
7396

7497
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]

0 commit comments

Comments
 (0)