Skip to content

Commit 890ae75

Browse files
committed
Add tests
1 parent a9d884f commit 890ae75

File tree

3 files changed

+17
-23
lines changed

3 files changed

+17
-23
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ julia = "1.9"
5858
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5959
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6060
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
61+
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
6162
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6263
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
6364
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
@@ -71,4 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7172
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7273

7374
[targets]
74-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]
75+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim"]

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver
4141
f! = FunctionWrapper{iip}(prob.f, prob.p)
4242
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)
4343

44-
lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!,
45-
J = prob.f.jac_prototype, alg.autodiff,
46-
output_length = length(prob.f.resid_prototype))
44+
resid_prototype = prob.f.resid_prototype === nothing ?
45+
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
46+
prob.f.resid_prototype
47+
48+
lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
49+
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
4750
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
4851

4952
return LeastSquaresOptimCache(prob, alg, allocated_prob,

test/nonlinear_least_squares.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
2+
import LeastSquaresOptim
23

34
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
45
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
@@ -25,22 +26,11 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
2526
prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
2627
resid_prototype = zero(y_target)), θ_init, x)
2728

28-
sol = solve(prob_oop, GaussNewton(; linsolve = NormalCholeskyFactorization());
29-
maxiters = 1000, abstol = 1e-8)
30-
@test SciMLBase.successful_retcode(sol)
31-
@test norm(sol.resid) < 1e-6
32-
33-
sol = solve(prob_iip, GaussNewton(; linsolve = NormalCholeskyFactorization());
34-
maxiters = 1000, abstol = 1e-8)
35-
@test SciMLBase.successful_retcode(sol)
36-
@test norm(sol.resid) < 1e-6
37-
38-
sol = solve(prob_oop, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
39-
maxiters = 1000, abstol = 1e-8)
40-
@test SciMLBase.successful_retcode(sol)
41-
@test norm(sol.resid) < 1e-6
42-
43-
sol = solve(prob_iip, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
44-
maxiters = 1000, abstol = 1e-8)
45-
@test SciMLBase.successful_retcode(sol)
46-
@test norm(sol.resid) < 1e-6
29+
nlls_problems = [prob_oop, prob_iip]
30+
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
31+
32+
for prob in nlls_problems, solver in solvers
33+
@time sol = solve(prob, solver; maxiters = 1000, abstol = 1e-8)
34+
@test SciMLBase.successful_retcode(sol)
35+
@test norm(sol.resid) < 1e-6
36+
end

0 commit comments

Comments
 (0)