Skip to content

Commit d5d3dae

Browse files
committed
[WIP] AD in Bisection
1 parent f8af971 commit d5d3dae

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/scalar.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,39 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
1919
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

22+
function solve(prob::NonlinearProblem{uType, iip, <:ForwardDiff.Dual{T,V,P}}, alg::Bisection, args...; kwargs...) where {uType, iip, T, V, P}
23+
prob_nodual = NonlinearProblem(prob.f, prob.u0, ForwardDiff.value(prob.p); prob.kwargs...)
24+
sol = solve(prob_nodual, alg, args...; kwargs...)
25+
# f, x and p always satisfy
26+
# f(x, p) = 0
27+
# dx * f_x(x, p) + dp * f_p(x, p) = 0
28+
# dx / dp = - f_p(x, p) / f_x(x, p)
29+
f_p = (p) -> prob.f(sol.left, p)
30+
f_x = (x) -> prob.f(x, ForwardDiff.value(prob.p))
31+
d_p = ForwardDiff.derivative(f_p, ForwardDiff.value(prob.p))
32+
d_x = ForwardDiff.derivative(f_x, sol.left)
33+
partials = - d_p / d_x * ForwardDiff.partials(prob.p)
34+
return BracketingSolution(ForwardDiff.Dual{T,V,P}(sol.left, partials), ForwardDiff.Dual{T,V,P}(sol.right, partials), sol.retcode)
35+
end
36+
37+
# still WIP
38+
function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:ForwardDiff.Dual{T,V,P}, N}}, alg::Bisection, args...; kwargs...) where {uType, iip, T, V, P, N}
39+
p_nodual = ForwardDiff.value.(prob.p)
40+
prob_nodual = NonlinearProblem(prob.f, prob.u0, p_nodual; prob.kwargs...)
41+
sol = solve(prob_nodual, alg, args...; kwargs...)
42+
# f, x and p always satisfy
43+
# f(x, p) = 0
44+
# dx * f_x(x, p) + dp * f_p(x, p) = 0
45+
# dx / dp = - f_p(x, p) / f_x(x, p)
46+
f_p = (p) -> [ prob.f(sol.left, p) ]
47+
f_x = (x) -> prob.f(x, p_nodual)
48+
d_p = ForwardDiff.jacobian(f_p, p_nodual)
49+
d_x = ForwardDiff.derivative(f_x, sol.left)
50+
@. d_p = - d_p / d_x
51+
@show ForwardDiff.partials.(prob.p)
52+
return ForwardDiff.Dual{T,V,P}(sol.left, d_p * ForwardDiff.partials.(prob.p))
53+
end
54+
2255
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
2356
f = Base.Fix2(prob.f, prob.p)
2457
left, right = prob.u0

test/runtests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ for p in 1.1:0.1:100.0
6969
@test ForwardDiff.derivative(g, p) 1/(2*sqrt(p))
7070
end
7171

72+
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
73+
t = (p) -> [sqrt(p[2] / p[1])]
74+
g = function (p)
75+
probN = NonlinearProblem{false}(f, u0, p)
76+
sol = solve(probN, Bisection())
77+
return [sol.left]
78+
end
79+
80+
for p1 in 1.0:1.0:100.0
81+
for p2 in 1.0:1.0:100.0
82+
p = [p1, p2]
83+
@show p
84+
@test g(p) [sqrt(p[2] / p[1])]
85+
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
86+
end
87+
end
88+
7289
# Error Checks
7390

7491
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]

0 commit comments

Comments
 (0)