@@ -5,8 +5,17 @@ function SciMLBase.solve(
55 sol, partials = __nlsolve_ad(prob, alg, args... ; kwargs... )
66 dual_soln = __nlsolve_dual_soln(sol. u, partials, prob. p)
77 return SciMLBase. build_solution(
8- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
9- sol. original)
8+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
9+ end
10+
11+ function SciMLBase. solve(
12+ prob:: NonlinearLeastSquaresProblem {<: AbstractArray ,
13+ iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
14+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
15+ sol, partials = __nlsolve_ad(prob, alg, args... ; kwargs... )
16+ dual_soln = __nlsolve_dual_soln(sol. u, partials, prob. p)
17+ return SciMLBase. build_solution(
18+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1019end
1120
1221for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
2433 end
2534end
2635
27- function __nlsolve_ad(prob, alg, args... ; kwargs... )
36+ function __nlsolve_ad(
37+ prob:: Union{IntervalNonlinearProblem, NonlinearProblem} , alg, args... ; kwargs... )
2838 p = value(prob. p)
2939 if prob isa IntervalNonlinearProblem
3040 tspan = value.(prob. tspan)
@@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
5565 return sol, partials
5666end
5767
68+ function __nlsolve_ad(prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
69+ p = value(prob. p)
70+ u0 = value(prob. u0)
71+ newprob = NonlinearLeastSquaresProblem(prob. f, u0, p; prob. kwargs... )
72+
73+ sol = solve(newprob, alg, args... ; kwargs... )
74+
75+ uu = sol. u
76+
77+ # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78+ # nested autodiff as the last resort
79+ if SciMLBase. has_vjp(prob. f)
80+ if isinplace(prob)
81+ _F = @closure (du, u, p) -> begin
82+ resid = similar(du, length(sol. resid))
83+ prob. f(resid, u, p)
84+ prob. f. vjp(du, resid, u, p)
85+ du .*= 2
86+ return nothing
87+ end
88+ else
89+ _F = @closure (u, p) -> begin
90+ resid = prob. f(u, p)
91+ return reshape(2 .* prob. f. vjp(resid, u, p), size(u))
92+ end
93+ end
94+ elseif SciMLBase. has_jac(prob. f)
95+ if isinplace(prob)
96+ _F = @closure (du, u, p) -> begin
97+ J = similar(du, length(sol. resid), length(u))
98+ prob. f. jac(J, u, p)
99+ resid = similar(du, length(sol. resid))
100+ prob. f(resid, u, p)
101+ mul!(reshape(du, 1 , :), vec(resid)' , J, 2 , false )
102+ return nothing
103+ end
104+ else
105+ _F = @closure (u, p) -> begin
106+ return reshape(2 .* vec(prob. f(u, p))' * prob. f. jac(u, p), size(u))
107+ end
108+ end
109+ else
110+ if isinplace(prob)
111+ _F = @closure (du, u, p) -> begin
112+ resid = similar(du, length(sol. resid))
113+ res = DiffResults. DiffResult(
114+ resid, similar(du, length(sol. resid), length(u)))
115+ _f = @closure (du, u) -> prob. f(du, u, p)
116+ ForwardDiff. jacobian!(res, _f, resid, u)
117+ mul!(reshape(du, 1 , :), vec(DiffResults. value(res))' ,
118+ DiffResults. jacobian(res), 2 , false )
119+ return nothing
120+ end
121+ else
122+ # For small problems, nesting ForwardDiff is actually quite fast
123+ if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol. resid) ≥ 50 )
124+ _F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob. f, u, p)
125+ else
126+ _F = @closure (u, p) -> begin
127+ T = promote_type(eltype(u), eltype(p))
128+ res = DiffResults. DiffResult(
129+ similar(u, T, size(sol. resid)), similar(
130+ u, T, length(sol. resid), length(u)))
131+ ForwardDiff. jacobian!(res, Base. Fix2(prob. f, p), u)
132+ return reshape(
133+ 2 .* vec(DiffResults. value(res))' * DiffResults. jacobian(res),
134+ size(u))
135+ end
136+ end
137+ end
138+ end
139+
140+ f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
141+ f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
142+
143+ z_arr = - f_x \ f_p
144+
145+ pp = prob. p
146+ sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff. partials(p), z)
147+ if uu isa Number
148+ partials = sum(sumfun, zip(z_arr, pp))
149+ elseif p isa Number
150+ partials = sumfun((z_arr, pp))
151+ else
152+ partials = sum(sumfun, zip(eachcol(z_arr), pp))
153+ end
154+
155+ return sol, partials
156+ end
157+
58158@inline function __nlsolve_∂f_∂p(prob, f:: F , u, p) where {F}
59159 if isinplace(prob)
60160 __f = p -> begin
0 commit comments