Skip to content

Commit e032a13

Browse files
committed
test: NLLS forwarddiff rules testing
1 parent 2dcdbbd commit e032a13

File tree

5 files changed

+122
-6
lines changed

5 files changed

+122
-6
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ function CommonSolve.solve(
7575
end
7676

7777
function CommonSolve.solve(
78-
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
78+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
79+
alg::AbstractSimpleNonlinearSolveAlgorithm,
7980
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
8081
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
8182
sensealg = prob.kwargs[:sensealg]
@@ -86,7 +87,8 @@ function CommonSolve.solve(
8687
p === nothing, alg, args...; prob.kwargs..., kwargs...)
8788
end
8889

89-
function simplenonlinearsolve_solve_up(prob::ImmutableNonlinearProblem, sensealg, u0,
90+
function simplenonlinearsolve_solve_up(
91+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
9092
u0_changed, p, p_changed, alg, args...; kwargs...)
9193
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
9294
return SciMLBase.__solve(prob, alg, args...; kwargs...)

lib/SimpleNonlinearSolve/src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function SciMLBase.__solve(
4343

4444
@bb xo = similar(x)
4545
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
46-
safe_similar(fx) : nothing
46+
safe_similar(fx) : fx
4747
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
4848
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
4949

lib/SimpleNonlinearSolve/src/trust_region.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi
9494

9595
@bb xo = copy(x)
9696
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
97-
safe_similar(fx) : nothing
97+
safe_similar(fx) : fx
9898
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
9999
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
100100

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
183183
end
184184
if extras isa AnalyticJacobian
185185
if SciMLBase.isinplace(prob)
186-
prob.jac(J, x, prob.p)
186+
prob.f.jac(J, x, prob.p)
187187
return J
188188
else
189-
return prob.jac(x, prob.p)
189+
return prob.f.jac(x, prob.p)
190190
end
191191
end
192192
if SciMLBase.isinplace(prob)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,115 @@
1+
@testitem "ForwardDiff.jl Integration NonlinearLeastSquaresProblem" tags=[:core] begin
2+
using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
3+
Zygote, ReverseDiff
4+
using DifferentiationInterface
15

6+
const DI = DifferentiationInterface
7+
8+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
9+
10+
θ_true = [1.0, 0.1, 2.0, 0.5]
11+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
12+
y_target = true_function(x, θ_true)
13+
14+
loss_function(θ, p) = true_function(p, θ) .- y_target
15+
16+
loss_function_jac(θ, p) = ForwardDiff.jacobian(Base.Fix2(loss_function, p), θ)
17+
18+
loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
19+
20+
function loss_function!(resid, θ, p)
21+
= true_function(p, θ)
22+
@. resid =- y_target
23+
return
24+
end
25+
26+
function loss_function_jac!(J, θ, p)
27+
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
28+
return
29+
end
30+
31+
function loss_function_vjp!(vJ, v, θ, p)
32+
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
33+
return
34+
end
35+
36+
θ_init = θ_true .+ 0.1
37+
38+
@testset for alg in (
39+
SimpleGaussNewton(),
40+
SimpleGaussNewton(; autodiff = AutoForwardDiff()),
41+
SimpleGaussNewton(; autodiff = AutoFiniteDiff()),
42+
SimpleGaussNewton(; autodiff = AutoReverseDiff())
43+
)
44+
function obj_1(p)
45+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
46+
sol = solve(prob_oop, alg)
47+
return sum(abs2, sol.u)
48+
end
49+
50+
function obj_2(p)
51+
ff = NonlinearFunction{false}(
52+
loss_function; resid_prototype = zeros(length(y_target)))
53+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
54+
sol = solve(prob_oop, alg)
55+
return sum(abs2, sol.u)
56+
end
57+
58+
function obj_3(p)
59+
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
60+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
61+
sol = solve(prob_oop, alg)
62+
return sum(abs2, sol.u)
63+
end
64+
65+
finitediff = DI.gradient(obj_1, AutoFiniteDiff(), x)
66+
67+
fdiff1 = DI.gradient(obj_1, AutoForwardDiff(), x)
68+
fdiff2 = DI.gradient(obj_2, AutoForwardDiff(), x)
69+
fdiff3 = DI.gradient(obj_3, AutoForwardDiff(), x)
70+
71+
@test finitedifffdiff1 atol=1e-5
72+
@test finitedifffdiff2 atol=1e-5
73+
@test finitedifffdiff3 atol=1e-5
74+
@test fdiff1 fdiff2 fdiff3
75+
76+
function obj_4(p)
77+
prob_iip = NonlinearLeastSquaresProblem(
78+
NonlinearFunction{true}(
79+
loss_function!; resid_prototype = zeros(length(y_target))),
80+
θ_init,
81+
p)
82+
sol = solve(prob_iip, alg)
83+
return sum(abs2, sol.u)
84+
end
85+
86+
function obj_5(p)
87+
ff = NonlinearFunction{true}(
88+
loss_function!; resid_prototype = zeros(length(y_target)),
89+
jac = loss_function_jac!)
90+
prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
91+
sol = solve(prob_iip, alg)
92+
return sum(abs2, sol.u)
93+
end
94+
95+
function obj_6(p)
96+
ff = NonlinearFunction{true}(
97+
loss_function!; resid_prototype = zeros(length(y_target)),
98+
vjp = loss_function_vjp!)
99+
prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
100+
sol = solve(prob_iip, alg)
101+
return sum(abs2, sol.u)
102+
end
103+
104+
finitediff = DI.gradient(obj_4, AutoFiniteDiff(), x)
105+
106+
fdiff4 = DI.gradient(obj_4, AutoForwardDiff(), x)
107+
fdiff5 = DI.gradient(obj_5, AutoForwardDiff(), x)
108+
fdiff6 = DI.gradient(obj_6, AutoForwardDiff(), x)
109+
110+
@test finitedifffdiff4 atol=1e-5
111+
@test finitedifffdiff5 atol=1e-5
112+
@test finitedifffdiff6 atol=1e-5
113+
@test fdiff4 fdiff5 fdiff6
114+
end
115+
end

0 commit comments

Comments
 (0)