Skip to content

Commit 08c9701

Browse files
committed
Fix scalar Newton AD
1 parent d5d3dae commit 08c9701

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/scalar.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
1919
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

22+
function solve(prob::NonlinearProblem{<:Number, iip, <:ForwardDiff.Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {uType, iip, T, V, P}
23+
f = prob.f
24+
p = ForwardDiff.value(prob.p)
25+
u0 = ForwardDiff.value(prob.u0)
26+
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
27+
sol = solve(newprob, alg, args...; kwargs...)
28+
f_p = ForwardDiff.derivative(Base.Fix1(f, sol.u), p)
29+
f_x = ForwardDiff.derivative(Base.Fix2(f, p), sol.u)
30+
partials = (-f_p / f_x) * ForwardDiff.partials(prob.p)
31+
return NewtonSolution(ForwardDiff.Dual{T,V,P}(sol.u, partials), sol.retcode)
32+
end
33+
2234
function solve(prob::NonlinearProblem{uType, iip, <:ForwardDiff.Dual{T,V,P}}, alg::Bisection, args...; kwargs...) where {uType, iip, T, V, P}
2335
prob_nodual = NonlinearProblem(prob.f, prob.u0, ForwardDiff.value(prob.p); prob.kwargs...)
2436
sol = solve(prob_nodual, alg, args...; kwargs...)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ Move `x` one floating point towards x0.
216216
function prevfloat_tdir(x::T, x0::T, x1::T)::T where {T}
217217
x1 > x0 ? prevfloat(x) : nextfloat(x)
218218
end
219-
219+
220220
function nextfloat_tdir(x::T, x0::T, x1::T)::T where {T}
221221
x1 > x0 ? nextfloat(x) : prevfloat(x)
222222
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ end
5757
f, u0 = (u, p) -> u * u - p, 1.0
5858

5959
g = function (p)
60-
probN = NonlinearProblem{false}(f, u0, p)
60+
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
6161
sol = solve(probN, NewtonRaphson())
6262
return sol.u
6363
end
6464

65-
@test_broken ForwardDiff.derivative(g, 1.0) 0.5
65+
@test ForwardDiff.derivative(g, 1.0) 0.5
6666

6767
for p in 1.1:0.1:100.0
6868
@test g(p) sqrt(p)

0 commit comments

Comments
 (0)