|
1 | 1 | module NonlinearSolveBaseForwardDiffExt
|
2 | 2 |
|
3 | 3 | using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
|
| 4 | +using ArrayInterface: ArrayInterface |
4 | 5 | using CommonSolve: solve
|
| 6 | +using DifferentiationInterface: DifferentiationInterface, Constant |
5 | 7 | using FastClosures: @closure
|
6 | 8 | using ForwardDiff: ForwardDiff, Dual
|
| 9 | +using LinearAlgebra: mul! |
7 | 10 | using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
|
8 |
| - NonlinearProblem, |
9 |
| - NonlinearLeastSquaresProblem, remake |
| 11 | + NonlinearProblem, NonlinearLeastSquaresProblem, remake |
10 | 12 |
|
11 | 13 | using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
|
12 | 14 |
|
| 15 | +const DI = DifferentiationInterface |
| 16 | + |
13 | 17 | function NonlinearSolveBase.additional_incompatible_backend_check(
|
14 | 18 | prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
|
15 | 19 | return !ForwardDiff.can_dual(eltype(prob.u0))
|
@@ -50,22 +54,108 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
|
50 | 54 | return sol, partials
|
51 | 55 | end
|
52 | 56 |
|
| 57 | +function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( |
| 58 | + prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...) |
| 59 | + p = Utils.value(prob.p) |
| 60 | + newprob = remake(prob; p, u0 = Utils.value(prob.u0)) |
| 61 | + sol = solve(newprob, alg, args...; kwargs...) |
| 62 | + uu = sol.u |
| 63 | + |
| 64 | + # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use |
| 65 | + # nested autodiff as the last resort |
| 66 | + if SciMLBase.has_vjp(prob.f) |
| 67 | + if SciMLBase.isinplace(prob) |
| 68 | + vjp_fn = @closure (du, u, p) -> begin |
| 69 | + resid = Utils.safe_similar(du, length(sol.resid)) |
| 70 | + prob.f(resid, u, p) |
| 71 | + prob.f.vjp(du, resid, u, p) |
| 72 | + du .*= 2 |
| 73 | + return nothing |
| 74 | + end |
| 75 | + else |
| 76 | + vjp_fn = @closure (u, p) -> begin |
| 77 | + resid = prob.f(u, p) |
| 78 | + return reshape(2 .* prob.f.vjp(resid, u, p), size(u)) |
| 79 | + end |
| 80 | + end |
| 81 | + elseif SciMLBase.has_jac(prob.f) |
| 82 | + if SciMLBase.isinplace(prob) |
| 83 | + vjp_fn = @closure (du, u, p) -> begin |
| 84 | + J = Utils.safe_similar(du, length(sol.resid), length(u)) |
| 85 | + prob.f.jac(J, u, p) |
| 86 | + resid = Utils.safe_similar(du, length(sol.resid)) |
| 87 | + prob.f(resid, u, p) |
| 88 | + mul!(reshape(du, 1, :), vec(resid)', J, 2, false) |
| 89 | + return nothing |
| 90 | + end |
| 91 | + else |
| 92 | + vjp_fn = @closure (u, p) -> begin |
| 93 | + return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u)) |
| 94 | + end |
| 95 | + end |
| 96 | + else |
| 97 | + # For small problems, nesting ForwardDiff is actually quite fast |
| 98 | + autodiff = length(uu) + length(sol.resid) ≥ 50 ? |
| 99 | + NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) : |
| 100 | + AutoForwardDiff() |
| 101 | + |
| 102 | + if SciMLBase.isinplace(prob) |
| 103 | + vjp_fn = @closure (du, u, p) -> begin |
| 104 | + resid = Utils.safe_similar(du, length(sol.resid)) |
| 105 | + prob.f(resid, u, p) |
| 106 | + # Using `Constant` lead to dual ordering issues |
| 107 | + ff = @closure (du, u) -> prob.f(du, u, p) |
| 108 | + resid2 = copy(resid) |
| 109 | + DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,)) |
| 110 | + @. du *= 2 |
| 111 | + return nothing |
| 112 | + end |
| 113 | + else |
| 114 | + vjp_fn = @closure (u, p) -> begin |
| 115 | + v = prob.f(u, p) |
| 116 | + # Using `Constant` lead to dual ordering issues |
| 117 | + ff = Base.Fix2(prob.f, p) |
| 118 | + res = only(DI.pullback(ff, autodiff, u, (v,))) |
| 119 | + ArrayInterface.can_setindex(res) || return 2 .* res |
| 120 | + @. res *= 2 |
| 121 | + return res |
| 122 | + end |
| 123 | + end |
| 124 | + end |
| 125 | + |
| 126 | + Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p) |
| 127 | + Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p) |
| 128 | + z = -Jᵤ \ Jₚ |
| 129 | + pp = prob.p |
| 130 | + sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) |
| 131 | + |
| 132 | + if uu isa Number |
| 133 | + partials = sum(sumfun, zip(z, pp)) |
| 134 | + elseif p isa Number |
| 135 | + partials = sumfun((z, pp)) |
| 136 | + else |
| 137 | + partials = sum(sumfun, zip(eachcol(z), pp)) |
| 138 | + end |
| 139 | + |
| 140 | + return sol, partials |
| 141 | +end |
| 142 | + |
53 | 143 | function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
|
54 | 144 | if SciMLBase.isinplace(prob)
|
55 |
| - f = @closure p -> begin |
| 145 | + f2 = @closure p -> begin |
56 | 146 | du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
|
57 | 147 | f(du, u, p)
|
58 | 148 | return du
|
59 | 149 | end
|
60 | 150 | else
|
61 |
| - f = Base.Fix1(f, u) |
| 151 | + f2 = Base.Fix1(f, u) |
62 | 152 | end
|
63 | 153 | if p isa Number
|
64 |
| - return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1) |
| 154 | + return Utils.safe_reshape(ForwardDiff.derivative(f2, p), :, 1) |
65 | 155 | elseif u isa Number
|
66 |
| - return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :) |
| 156 | + return Utils.safe_reshape(ForwardDiff.gradient(f2, p), 1, :) |
67 | 157 | else
|
68 |
| - return ForwardDiff.jacobian(f, p) |
| 158 | + return ForwardDiff.jacobian(f2, p) |
69 | 159 | end
|
70 | 160 | end
|
71 | 161 |
|
|
0 commit comments