Skip to content

Commit b8aca89

Browse files
committed
Add tests for levenberg least squares
1 parent 6e2f58d commit b8aca89

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

test/nonlinear_least_squares.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
2+
3+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
4+
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
5+
6+
θ_true = [1.0, 0.1, 2.0, 0.5]
7+
8+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
9+
10+
y_target = true_function(x, θ_true)
11+
12+
function loss_function(θ, p)
13+
= true_function(p, θ)
14+
return abs2.(ŷ .- y_target)
15+
end
16+
17+
function loss_function(resid, θ, p)
18+
true_function(resid, p, θ)
19+
resid .= abs2.(resid .- y_target)
20+
return resid
21+
end
22+
23+
θ_init = θ_true .+ randn!(similar(θ_true)) * 0.1
24+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
25+
prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
26+
resid_prototype = zero(y_target)), θ_init, x)
27+
28+
# sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
29+
# @test SciMLBase.successful_retcode(sol)
30+
# @test norm(sol.resid) < 1e-6
31+
32+
# sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
33+
# @test SciMLBase.successful_retcode(sol)
34+
# @test norm(sol.resid) < 1e-6
35+
36+
sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 1000, abstol = 1e-8)
37+
@test SciMLBase.successful_retcode(sol)
38+
@test norm(sol.resid) < 1e-6
39+
40+
sol = solve(prob_iip, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
41+
maxiters = 1000, abstol = 1e-8)
42+
@test SciMLBase.successful_retcode(sol)
43+
@test norm(sol.resid) < 1e-6

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ end
1515
if GROUP == "All" || GROUP == "Core"
1616
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
18+
19+
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
1820
end
1921

2022
if GROUP == "All" || GROUP == "23TestProblems"

0 commit comments

Comments
 (0)