Skip to content

Commit c91b973

Browse files
committed
feat: support NLLS forward AD
1 parent 7f15b30 commit c91b973

File tree

3 files changed

+112
-8
lines changed

3 files changed

+112
-8
lines changed

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
module NonlinearSolveBaseForwardDiffExt
22

33
using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
4+
using ArrayInterface: ArrayInterface
45
using CommonSolve: solve
6+
using DifferentiationInterface: DifferentiationInterface, Constant
57
using FastClosures: @closure
68
using ForwardDiff: ForwardDiff, Dual
9+
using LinearAlgebra: mul!
710
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
8-
NonlinearProblem,
9-
NonlinearLeastSquaresProblem, remake
11+
NonlinearProblem, NonlinearLeastSquaresProblem, remake
1012

1113
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
1214

15+
const DI = DifferentiationInterface
16+
1317
function NonlinearSolveBase.additional_incompatible_backend_check(
1418
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
1519
return !ForwardDiff.can_dual(eltype(prob.u0))
@@ -50,22 +54,108 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
5054
return sol, partials
5155
end
5256

57+
function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
58+
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
59+
p = Utils.value(prob.p)
60+
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
61+
sol = solve(newprob, alg, args...; kwargs...)
62+
uu = sol.u
63+
64+
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
65+
# nested autodiff as the last resort
66+
if SciMLBase.has_vjp(prob.f)
67+
if SciMLBase.isinplace(prob)
68+
vjp_fn = @closure (du, u, p) -> begin
69+
resid = Utils.safe_similar(du, length(sol.resid))
70+
prob.f(resid, u, p)
71+
prob.f.vjp(du, resid, u, p)
72+
du .*= 2
73+
return nothing
74+
end
75+
else
76+
vjp_fn = @closure (u, p) -> begin
77+
resid = prob.f(u, p)
78+
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
79+
end
80+
end
81+
elseif SciMLBase.has_jac(prob.f)
82+
if SciMLBase.isinplace(prob)
83+
vjp_fn = @closure (du, u, p) -> begin
84+
J = Utils.safe_similar(du, length(sol.resid), length(u))
85+
prob.f.jac(J, u, p)
86+
resid = Utils.safe_similar(du, length(sol.resid))
87+
prob.f(resid, u, p)
88+
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
89+
return nothing
90+
end
91+
else
92+
vjp_fn = @closure (u, p) -> begin
93+
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
94+
end
95+
end
96+
else
97+
# For small problems, nesting ForwardDiff is actually quite fast
98+
autodiff = length(uu) + length(sol.resid) 50 ?
99+
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
100+
AutoForwardDiff()
101+
102+
if SciMLBase.isinplace(prob)
103+
vjp_fn = @closure (du, u, p) -> begin
104+
resid = Utils.safe_similar(du, length(sol.resid))
105+
prob.f(resid, u, p)
106+
# Using `Constant` lead to dual ordering issues
107+
ff = @closure (du, u) -> prob.f(du, u, p)
108+
resid2 = copy(resid)
109+
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
110+
@. du *= 2
111+
return nothing
112+
end
113+
else
114+
vjp_fn = @closure (u, p) -> begin
115+
v = prob.f(u, p)
116+
# Using `Constant` lead to dual ordering issues
117+
ff = Base.Fix2(prob.f, p)
118+
res = only(DI.pullback(ff, autodiff, u, (v,)))
119+
ArrayInterface.can_setindex(res) || return 2 .* res
120+
@. res *= 2
121+
return res
122+
end
123+
end
124+
end
125+
126+
Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
127+
Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
128+
z = -Jᵤ \ Jₚ
129+
pp = prob.p
130+
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
131+
132+
if uu isa Number
133+
partials = sum(sumfun, zip(z, pp))
134+
elseif p isa Number
135+
partials = sumfun((z, pp))
136+
else
137+
partials = sum(sumfun, zip(eachcol(z), pp))
138+
end
139+
140+
return sol, partials
141+
end
142+
53143
function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
54144
if SciMLBase.isinplace(prob)
55-
f = @closure p -> begin
145+
f2 = @closure p -> begin
56146
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
57147
f(du, u, p)
58148
return du
59149
end
60150
else
61-
f = Base.Fix1(f, u)
151+
f2 = Base.Fix1(f, u)
62152
end
63153
if p isa Number
64-
return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1)
154+
return Utils.safe_reshape(ForwardDiff.derivative(f2, p), :, 1)
65155
elseif u isa Number
66-
return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :)
156+
return Utils.safe_reshape(ForwardDiff.gradient(f2, p), 1, :)
67157
else
68-
return ForwardDiff.jacobian(f, p)
158+
return ForwardDiff.jacobian(f2, p)
69159
end
70160
end
71161

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ function CommonSolve.solve(
6161
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
6262
end
6363

64+
function CommonSolve.solve(
65+
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
66+
<:Union{
67+
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
68+
alg::AbstractSimpleNonlinearSolveAlgorithm,
69+
args...;
70+
kwargs...) where {T, V, P, iip}
71+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
72+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
73+
return SciMLBase.build_solution(
74+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
75+
end
76+
6477
function CommonSolve.solve(
6578
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
6679
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)

lib/SimpleNonlinearSolve/src/raphson.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ end
2424
const SimpleGaussNewton = SimpleNewtonRaphson
2525

2626
function SciMLBase.__solve(
27-
prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...;
27+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
28+
alg::SimpleNewtonRaphson, args...;
2829
abstol = nothing, reltol = nothing, maxiters = 1000,
2930
alias_u0 = false, termination_condition = nothing, kwargs...)
3031
x = Utils.maybe_unaliased(prob.u0, alias_u0)

0 commit comments

Comments
 (0)