Skip to content

Commit 0a9551f

Browse files
AdityaPandeyCNChrisRackauckas
authored andcommitted
adeed nonlinear-solve wrappers code split out of OptimizationSciPy into NonlinearSolve
Signed-off-by: AdityaPandeyCN <[email protected]>
1 parent 4ac8ae4 commit 0a9551f

File tree

6 files changed

+280
-1
lines changed

6 files changed

+280
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
4444
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4545
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4646
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
47+
PythonCall = "9d6bc502-a68f-5bec-b839-aff6a2a0a179"
4748

4849
[sources.BracketingNonlinearSolve]
4950
path = "lib/BracketingNonlinearSolve"
@@ -74,6 +75,7 @@ NonlinearSolvePETScExt = ["PETSc", "MPI"]
7475
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
7576
NonlinearSolveSpeedMappingExt = "SpeedMapping"
7677
NonlinearSolveSundialsExt = "Sundials"
78+
NonlinearSolveSciPyExt = "PythonCall"
7779

7880
[compat]
7981
ADTypes = "1.9"
@@ -135,6 +137,7 @@ SymbolicIndexingInterface = "0.3.36"
135137
Test = "1.10"
136138
Zygote = "0.6.69, 0.7"
137139
julia = "1.10"
140+
PythonCall = "0.9"
138141

139142
[extras]
140143
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

ext/NonlinearSolveSciPyExt.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
module NonlinearSolveSciPyExt
2+
3+
# This file is loaded as an extension when PythonCall is available
4+
using PythonCall
5+
const scipy_optimize = try
6+
pyimport("scipy.optimize")
7+
catch err
8+
error("Python package `scipy` could not be imported. Install it in the Python environment used by PythonCall.")
9+
end
10+
11+
using SciMLBase
12+
using NonlinearSolve
13+
14+
# Re-export algorithm type so that `using NonlinearSolve` brings it in when the
15+
# extension is loaded. (Matches convention in other extensions.)
16+
import ..NonlinearSolve: SciPyLeastSquares, SciPyRoot, SciPyRootScalar
17+
using NonlinearSolveBase: construct_extension_function_wrapper
18+
19+
""" Internal: wrap a Julia residual function into a Python callable """
20+
function _make_py_residual(f, p)
21+
return pyfunc(x_py -> begin
22+
x = Vector{Float64}(x_py) # convert NumPy array → Julia Vector
23+
r = f(x, p)
24+
return r # auto-convert back to NumPy
25+
end)
26+
end
27+
28+
""" Internal: wrap a Julia scalar function into a Python callable """
29+
function _make_py_scalar(f, p)
30+
return pyfunc(x_py -> begin
31+
x = Float64(x_py)
32+
return f(x, p)
33+
end)
34+
end
35+
36+
function SciMLBase.__solve(prob::SciMLBase.NonlinearLeastSquaresProblem, alg::SciPyLeastSquares;
37+
abstol = nothing, maxiters = 10_000, alias_u0::Bool = false,
38+
kwargs...)
39+
# Construct Python residual
40+
py_f = _make_py_residual(prob.f, prob.p)
41+
42+
# Bounds handling (lb/ub may be missing)
43+
has_lb = hasproperty(prob, :lb)
44+
has_ub = hasproperty(prob, :ub)
45+
if has_lb || has_ub
46+
lb = has_lb ? getproperty(prob, :lb) : fill(-Inf, length(prob.u0))
47+
ub = has_ub ? getproperty(prob, :ub) : fill( Inf, length(prob.u0))
48+
bounds = (lb, ub)
49+
else
50+
bounds = nothing
51+
end
52+
53+
# Call SciPy
54+
res = scipy_optimize.least_squares(py_f, collect(prob.u0);
55+
method = alg.method,
56+
loss = alg.loss,
57+
max_nfev = maxiters,
58+
bounds = bounds === nothing ? py_none : bounds,
59+
kwargs...)
60+
61+
# Extract solution
62+
u_vec = Vector{Float64}(res.x)
63+
resid = Vector{Float64}(res.fun)
64+
65+
u = prob.u0 isa Number ? u_vec[1] : reshape(u_vec, size(prob.u0))
66+
67+
ret = res.success ? SciMLBase.ReturnCode.Success : SciMLBase.ReturnCode.Failure
68+
njev = try
69+
Int(res.njev)
70+
catch
71+
0
72+
end
73+
stats = SciMLBase.NLStats(res.nfev, njev, 0, 0, res.nfev)
74+
75+
return SciMLBase.build_solution(prob, alg, u, resid; retcode = ret,
76+
original = res, stats = stats)
77+
end
78+
79+
function SciMLBase.__solve(prob::SciMLBase.NonlinearProblem, alg::SciPyRoot;
80+
abstol = nothing, maxiters = 10_000, alias_u0::Bool = false,
81+
kwargs...)
82+
# Get in-place residual wrapper from NonlinearSolveBase.
83+
f!, u0, resid = construct_extension_function_wrapper(prob; alias_u0)
84+
85+
py_f = pyfunc(x_py -> begin
86+
x = Vector{Float64}(x_py)
87+
f!(resid, x)
88+
return resid
89+
end)
90+
91+
tol = abstol === nothing ? nothing : abstol
92+
93+
res = scipy_optimize.root(py_f, collect(u0);
94+
method = alg.method,
95+
tol = tol,
96+
options = Dict("maxiter" => maxiters),
97+
kwargs...)
98+
99+
u_vec = Vector{Float64}(res.x)
100+
f!(resid, u_vec) # update residual
101+
102+
u_out = prob.u0 isa Number ? u_vec[1] : reshape(u_vec, size(prob.u0))
103+
104+
ret = res.success ? SciMLBase.ReturnCode.Success : SciMLBase.ReturnCode.Failure
105+
nfev = try Int(res.nfev) catch; 0 end
106+
niter = try Int(res.nit) catch; 0 end
107+
stats = SciMLBase.NLStats(nfev, 0, 0, 0, niter)
108+
109+
return SciMLBase.build_solution(prob, alg, u_out, resid; retcode = ret,
110+
original = res, stats = stats)
111+
end
112+
113+
function SciMLBase.__solve(prob::SciMLBase.IntervalNonlinearProblem, alg::SciPyRootScalar;
114+
abstol = nothing, maxiters = 10_000, kwargs...)
115+
f = prob.f
116+
p = prob.p
117+
py_f = _make_py_scalar(f, p)
118+
119+
a, b = prob.tspan
120+
121+
res = scipy_optimize.root_scalar(py_f;
122+
method = alg.method,
123+
bracket = (a, b),
124+
maxiter = maxiters,
125+
xtol = abstol,
126+
kwargs...)
127+
128+
u_root = res.root
129+
resid = f(u_root, p)
130+
131+
ret = res.converged ? SciMLBase.ReturnCode.Success : SciMLBase.ReturnCode.Failure
132+
nfev = try Int(res.function_calls) catch; 0 end
133+
niter = try Int(res.iterations) catch; 0 end
134+
stats = SciMLBase.NLStats(nfev, 0, 0, 0, niter)
135+
136+
return SciMLBase.build_solution(prob, alg, u_root, resid; retcode = ret,
137+
original = res, stats = stats)
138+
end
139+
140+
end # module

src/NonlinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ export NonlinearSolvePolyAlgorithm, FastShortcutNonlinearPolyalg, FastShortcutNL
117117

118118
# Extension Algorithms
119119
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
120-
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
120+
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL, SciPyLeastSquares,
121+
SciPyLeastSquaresTRF, SciPyLeastSquaresDogbox, SciPyLeastSquaresLM,
122+
SciPyRoot, SciPyRootScalar
121123
export PETScSNES, CMINPACK
122124

123125
end

src/extension_algs.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,80 @@ function PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing,
488488
end
489489
return PETScSNES(petsclib, mpi_comm, autodiff, kwargs)
490490
end
491+
492+
"""
493+
SciPyLeastSquares(; method = "trf", loss = "linear")
494+
495+
Wrapper over `scipy.optimize.least_squares` (via PythonCall) for solving
496+
`NonlinearLeastSquaresProblem`s. The keyword arguments correspond to the
497+
`method` ("trf", "dogbox", "lm") and the robust loss function ("linear",
498+
"soft_l1", "huber", "cauchy", "arctan").
499+
500+
!!! note
501+
502+
This algorithm is only available if `PythonCall.jl` is installed and the
503+
Python package `scipy` is import-able. Attempting to construct or use the
504+
algorithm otherwise will throw an informative error.
505+
"""
506+
@concrete struct SciPyLeastSquares <: AbstractNonlinearSolveAlgorithm
507+
method::String
508+
loss::String
509+
name::Symbol
510+
end
511+
512+
function SciPyLeastSquares(; method::String = "trf", loss::String = "linear")
513+
if Base.get_extension(@__MODULE__, :NonlinearSolveSciPyExt) === nothing
514+
error("`SciPyLeastSquares` requires `PythonCall.jl` (and SciPy) to be loaded")
515+
end
516+
valid_methods = ("trf", "dogbox", "lm")
517+
valid_losses = ("linear", "soft_l1", "huber", "cauchy", "arctan")
518+
method in valid_methods ||
519+
throw(ArgumentError("Invalid method: $method. Valid methods are: $(join(valid_methods, ", "))"))
520+
loss in valid_losses ||
521+
throw(ArgumentError("Invalid loss: $loss. Valid loss functions are: $(join(valid_losses, ", "))"))
522+
return SciPyLeastSquares(method, loss, :SciPyLeastSquares)
523+
end
524+
525+
SciPyLeastSquaresTRF() = SciPyLeastSquares(method = "trf")
526+
SciPyLeastSquaresDogbox() = SciPyLeastSquares(method = "dogbox")
527+
SciPyLeastSquaresLM() = SciPyLeastSquares(method = "lm")
528+
529+
"""
530+
SciPyRoot(; method = "hybr")
531+
532+
Wrapper over `scipy.optimize.root` for solving `NonlinearProblem`s. Available
533+
methods include "hybr" (default), "lm", "broyden1", "broyden2", "anderson",
534+
"diagbroyden", "linearmixing", "excitingmixing", "krylov", "df-sane" – any
535+
method accepted by SciPy.
536+
"""
537+
@concrete struct SciPyRoot <: AbstractNonlinearSolveAlgorithm
538+
method::String
539+
name::Symbol
540+
end
541+
542+
function SciPyRoot(; method::String = "hybr")
543+
if Base.get_extension(@__MODULE__, :NonlinearSolveSciPyExt) === nothing
544+
error("`SciPyRoot` requires `PythonCall.jl` (and SciPy) to be loaded")
545+
end
546+
return SciPyRoot(method, :SciPyRoot)
547+
end
548+
549+
"""
550+
SciPyRootScalar(; method = "brentq")
551+
552+
Wrapper over `scipy.optimize.root_scalar` for scalar `IntervalNonlinearProblem`s
553+
(bracketing problems). The default method uses Brent's algorithm ( "brentq").
554+
Other valid options: "bisect", "brentq", "brenth", "ridder", "toms748",
555+
"secant", "newton" (secant/Newton do not require brackets).
556+
"""
557+
@concrete struct SciPyRootScalar <: AbstractNonlinearSolveAlgorithm
558+
method::String
559+
name::Symbol
560+
end
561+
562+
function SciPyRootScalar(; method::String = "brentq")
563+
if Base.get_extension(@__MODULE__, :NonlinearSolveSciPyExt) === nothing
564+
error("`SciPyRootScalar` requires `PythonCall.jl` (and SciPy) to be loaded")
565+
end
566+
return SciPyRootScalar(method, :SciPyRootScalar)
567+
end

test/wrappers/least_squares_tests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,30 @@ end
113113
sol = solve(prob_sa, FastLevenbergMarquardtJL())
114114
@test maximum(abs, sol.resid) < 1e-6
115115
end
116+
117+
@testitem "SciPyLeastSquares" setup=[WrapperNLLSSetup] tags=[:wrappers] begin
118+
success = false
119+
try
120+
import PythonCall
121+
spopt = PythonCall.pyimport("scipy.optimize")
122+
success = true
123+
catch
124+
end
125+
if success
126+
xdata = collect(0:0.1:1)
127+
ydata = 2.0 .* xdata .+ 1.0 .+ 0.1 .* randn(length(xdata))
128+
function residuals(params, p=nothing)
129+
a, b = params
130+
return ydata .- (a .* xdata .+ b)
131+
end
132+
x0_ls = [1.0, 0.0]
133+
prob = NonlinearLeastSquaresProblem(residuals, x0_ls)
134+
sol = solve(prob, SciPyLeastSquaresTRF())
135+
@test SciMLBase.successful_retcode(sol)
136+
prob_bounded = NonlinearLeastSquaresProblem(residuals, x0_ls; lb = [0.0,-2.0], ub = [5.0,3.0])
137+
sol2 = solve(prob_bounded, SciPyLeastSquares(method="trf"))
138+
@test SciMLBase.successful_retcode(sol2)
139+
else
140+
@test true # skip: SciPy not present
141+
end
142+
end

test/wrappers/rootfind_tests.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,33 @@ end
185185

186186
@test_throws AssertionError solve(probN, PETScSNES(); abstol = 1e-8)
187187
end
188+
189+
@testitem "SciPyRoot + SciPyRootScalar" tags=[:wrappers] begin
190+
success = false
191+
try
192+
import PythonCall
193+
PythonCall.pyimport("scipy.optimize")
194+
success = true
195+
catch
196+
end
197+
if success
198+
# Vector root example
199+
function fvec(u, p)
200+
return [2 - 2u[1]; u[1] - 4u[2]]
201+
end
202+
u0 = zeros(2)
203+
prob_vec = NonlinearProblem(fvec, u0)
204+
sol_vec = solve(prob_vec, SciPyRoot())
205+
@test SciMLBase.successful_retcode(sol_vec)
206+
@test maximum(abs, sol_vec.resid) < 1e-6
207+
208+
# Scalar bracketing root example
209+
fscalar(x, p) = x^2 - 2
210+
prob_interval = IntervalNonlinearProblem(fscalar, (1.0, 2.0))
211+
sol_scalar = solve(prob_interval, SciPyRootScalar())
212+
@test SciMLBase.successful_retcode(sol_scalar)
213+
@test abs(sol_scalar.u - sqrt(2)) < 1e-6
214+
else
215+
@test true
216+
end
217+
end

0 commit comments

Comments
 (0)