Skip to content

Commit c65d083

Browse files
removed unnecessary try/catch and some refactoring
Signed-off-by: AdityaPandeyCN <[email protected]>
1 parent 319abc4 commit c65d083

File tree

2 files changed

+95
-71
lines changed

2 files changed

+95
-71
lines changed

lib/OptimizationSciPy/src/OptimizationSciPy.jl

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,31 @@ function ensure_julia_array(x, ::Type{T}=Float64) where T
1717
end
1818

1919
function safe_get_message(result)
20-
try
21-
pyconvert(String, result.message)
22-
catch
23-
"Optimization completed"
20+
pyhasattr(result, "message") || return "Optimization completed"
21+
msg = result.message
22+
if pyisinstance(msg, pybuiltins.str)
23+
return pyconvert(String, msg)
24+
end
25+
if pyisinstance(msg, pybuiltins.list) || pyisinstance(msg, pybuiltins.tuple)
26+
return join(pyconvert(Vector{String}, msg), ", ")
2427
end
28+
return string(pytypeof(msg))
2529
end
2630

2731
function safe_to_float(x)
2832
x isa Float64 && return x
29-
try
30-
return pyconvert(Float64, x)
31-
catch
32-
33+
x isa Number && return Float64(x)
34+
35+
if x isa Py
3336
if pyhasattr(x, "item")
34-
try
35-
return pyconvert(Float64, x.item())
36-
catch
37-
end
38-
end
39-
try
40-
return pyconvert(Float64, pybuiltins.float(x))
41-
catch
42-
error("Cannot convert Python object to Float64: $(typeof(x))")
37+
v = pyconvert(Float64, x.item(), nothing)
38+
v !== nothing && return v
4339
end
40+
v = pyconvert(Float64, x, nothing)
41+
v !== nothing && return v
4442
end
43+
44+
error("Cannot convert object to Float64: $(typeof(x))")
4545
end
4646

4747
function extract_stats(result, time_elapsed)
@@ -52,34 +52,19 @@ function extract_stats(result, time_elapsed)
5252
:gevals => 0,
5353
:hevals => 0
5454
)
55-
if pyhasattr(result, "nit")
56-
try
57-
stats_dict[:iterations] = pyconvert(Int, result.nit)
58-
catch
59-
end
55+
if pyhasattr(result, "nit") && !pyis(result.nit, pybuiltins.None)
56+
stats_dict[:iterations] = pyconvert(Int, result.nit)
6057
end
61-
if pyhasattr(result, "nfev")
62-
try
63-
stats_dict[:fevals] = pyconvert(Int, result.nfev)
64-
catch
65-
end
58+
if pyhasattr(result, "nfev") && !pyis(result.nfev, pybuiltins.None)
59+
stats_dict[:fevals] = pyconvert(Int, result.nfev)
6660
end
67-
if pyhasattr(result, "njev")
68-
try
69-
stats_dict[:gevals] = pyconvert(Int, result.njev)
70-
catch
71-
end
72-
elseif pyhasattr(result, "ngrad")
73-
try
74-
stats_dict[:gevals] = pyconvert(Int, result.ngrad)
75-
catch
76-
end
61+
if pyhasattr(result, "njev") && !pyis(result.njev, pybuiltins.None)
62+
stats_dict[:gevals] = pyconvert(Int, result.njev)
63+
elseif pyhasattr(result, "ngrad") && !pyis(result.ngrad, pybuiltins.None)
64+
stats_dict[:gevals] = pyconvert(Int, result.ngrad)
7765
end
78-
if pyhasattr(result, "nhev")
79-
try
80-
stats_dict[:hevals] = pyconvert(Int, result.nhev)
81-
catch
82-
end
66+
if pyhasattr(result, "nhev") && !pyis(result.nhev, pybuiltins.None)
67+
stats_dict[:hevals] = pyconvert(Int, result.nhev)
8368
end
8469
return Optimization.OptimizationStats(; stats_dict...)
8570
end
@@ -105,7 +90,9 @@ function scipy_status_to_retcode(status::Int, success::Bool)
10590
end
10691
end
10792

108-
struct ScipyMinimize
93+
abstract type ScipyOptimizer end
94+
95+
struct ScipyMinimize <: ScipyOptimizer
10996
method::String
11097
function ScipyMinimize(method::String)
11198
valid_methods = ["Nelder-Mead", "Powell", "CG", "BFGS", "Newton-CG",
@@ -136,7 +123,7 @@ ScipyTrustNCG() = ScipyMinimize("trust-ncg")
136123
ScipyTrustKrylov() = ScipyMinimize("trust-krylov")
137124
ScipyTrustExact() = ScipyMinimize("trust-exact")
138125

139-
struct ScipyMinimizeScalar
126+
struct ScipyMinimizeScalar <: ScipyOptimizer
140127
method::String
141128
function ScipyMinimizeScalar(method::String="brent")
142129
valid_methods = ["brent", "bounded", "golden"]
@@ -151,7 +138,7 @@ ScipyBrent() = ScipyMinimizeScalar("brent")
151138
ScipyBounded() = ScipyMinimizeScalar("bounded")
152139
ScipyGolden() = ScipyMinimizeScalar("golden")
153140

154-
struct ScipyLeastSquares
141+
struct ScipyLeastSquares <: ScipyOptimizer
155142
method::String
156143
loss::String
157144
function ScipyLeastSquares(; method::String="trf", loss::String="linear")
@@ -171,7 +158,7 @@ ScipyLeastSquaresTRF() = ScipyLeastSquares(method="trf")
171158
ScipyLeastSquaresDogbox() = ScipyLeastSquares(method="dogbox")
172159
ScipyLeastSquaresLM() = ScipyLeastSquares(method="lm")
173160

174-
struct ScipyRootScalar
161+
struct ScipyRootScalar <: ScipyOptimizer
175162
method::String
176163
function ScipyRootScalar(method::String="brentq")
177164
valid_methods = ["brentq", "brenth", "bisect", "ridder", "newton", "secant", "halley", "toms748"]
@@ -182,7 +169,7 @@ struct ScipyRootScalar
182169
end
183170
end
184171

185-
struct ScipyRoot
172+
struct ScipyRoot <: ScipyOptimizer
186173
method::String
187174
function ScipyRoot(method::String="hybr")
188175
valid_methods = ["hybr", "lm", "broyden1", "broyden2", "anderson",
@@ -195,7 +182,7 @@ struct ScipyRoot
195182
end
196183
end
197184

198-
struct ScipyLinprog
185+
struct ScipyLinprog <: ScipyOptimizer
199186
method::String
200187
function ScipyLinprog(method::String="highs")
201188
valid_methods = ["highs", "highs-ds", "highs-ipm", "interior-point",
@@ -207,13 +194,13 @@ struct ScipyLinprog
207194
end
208195
end
209196

210-
struct ScipyMilp end
211-
struct ScipyDifferentialEvolution end
212-
struct ScipyBasinhopping end
213-
struct ScipyDualAnnealing end
214-
struct ScipyShgo end
215-
struct ScipyDirect end
216-
struct ScipyBrute end
197+
struct ScipyMilp <: ScipyOptimizer end
198+
struct ScipyDifferentialEvolution <: ScipyOptimizer end
199+
struct ScipyBasinhopping <: ScipyOptimizer end
200+
struct ScipyDualAnnealing <: ScipyOptimizer end
201+
struct ScipyShgo <: ScipyOptimizer end
202+
struct ScipyDirect <: ScipyOptimizer end
203+
struct ScipyBrute <: ScipyOptimizer end
217204

218205
for opt_type in [:ScipyMinimize, :ScipyDifferentialEvolution, :ScipyBasinhopping,
219206
:ScipyDualAnnealing, :ScipyShgo, :ScipyDirect, :ScipyBrute,
@@ -279,10 +266,7 @@ end
279266

280267
SciMLBase.allowsbounds(::ScipyRoot) = false
281268

282-
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::Union{ScipyMinimize,
283-
ScipyDifferentialEvolution, ScipyBasinhopping, ScipyDualAnnealing,
284-
ScipyShgo, ScipyDirect, ScipyBrute, ScipyMinimizeScalar,
285-
ScipyLeastSquares, ScipyRootScalar, ScipyRoot, ScipyLinprog, ScipyMilp};
269+
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::ScipyOptimizer;
286270
cons_tol = 1e-6,
287271
callback = (args...) -> (false),
288272
progress = false,
@@ -499,7 +483,7 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C})
499483
if cache.callback(opt_state, x...)
500484
error("Optimization halted by callback")
501485
end
502-
return cache.sense === Optimization.MaxSense ? -x[1] : x[1]
486+
return x[1]
503487
end
504488
kwargs = Dict{Symbol, Any}()
505489
if cache.opt.method == "bounded"
@@ -550,7 +534,16 @@ end
550534

551535
function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}) where
552536
{F,RC,LB,UB,LC,UC,S,O<:ScipyLeastSquares,D,P,C}
553-
_residuals = _create_loss(cache; vector_output=true)
537+
_residuals = nothing
538+
if hasfield(typeof(cache.f), :f) && (cache.f.f isa ResidualObjective)
539+
real_res = (cache.f.f)::ResidualObjective
540+
_residuals = function(θ)
541+
θ_julia = ensure_julia_array(θ, eltype(cache.u0))
542+
return real_res.residual(θ_julia, cache.p)
543+
end
544+
else
545+
_residuals = _create_loss(cache; vector_output=true)
546+
end
554547
kwargs = Dict{Symbol, Any}()
555548
kwargs[:method] = cache.opt.method
556549
kwargs[:loss] = cache.opt.loss
@@ -1318,6 +1311,30 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C})
13181311
stats = stats)
13191312
end
13201313

1314+
function SciMLBase.__init(prob::SciMLBase.NonlinearLeastSquaresProblem, opt::ScipyLeastSquares; kwargs...)
1315+
obj = ResidualObjective(prob.f)
1316+
optf = Optimization.OptimizationFunction(obj)
1317+
1318+
has_lb = hasproperty(prob, :lb)
1319+
has_ub = hasproperty(prob, :ub)
1320+
1321+
if has_lb || has_ub
1322+
lb_val = has_lb ? getproperty(prob, :lb) : fill(-Inf, length(prob.u0))
1323+
ub_val = has_ub ? getproperty(prob, :ub) : fill( Inf, length(prob.u0))
1324+
optprob = Optimization.OptimizationProblem(optf, prob.u0, prob.p;
1325+
lb = lb_val, ub = ub_val,
1326+
sense = Optimization.MinSense)
1327+
else
1328+
optprob = Optimization.OptimizationProblem(optf, prob.u0, prob.p;
1329+
sense = Optimization.MinSense)
1330+
end
1331+
1332+
return SciMLBase.__init(optprob, opt; kwargs...)
1333+
end
1334+
1335+
function SciMLBase.init(prob::SciMLBase.NonlinearLeastSquaresProblem, opt::ScipyLeastSquares; kwargs...)
1336+
SciMLBase.__init(prob, opt; kwargs...)
1337+
end
13211338

13221339
export ScipyMinimize, ScipyNelderMead, ScipyPowell, ScipyCG, ScipyBFGS, ScipyNewtonCG,
13231340
ScipyLBFGSB, ScipyTNC, ScipyCOBYLA, ScipyCOBYQA, ScipySLSQP, ScipyTrustConstr,
@@ -1344,7 +1361,9 @@ function _create_loss(cache; vector_output::Bool = false)
13441361
if cache.callback(opt_state, x...)
13451362
error("Optimization halted by callback")
13461363
end
1347-
return cache.sense === Optimization.MaxSense ? -x : x
1364+
1365+
arr = cache.sense === Optimization.MaxSense ? -x : x
1366+
return arr
13481367
end
13491368
else
13501369
return function (θ)
@@ -1378,4 +1397,11 @@ function _build_bounds(lb::AbstractVector, ub::AbstractVector)
13781397
return pylist([pytuple([lb[i], ub[i]]) for i in eachindex(lb)])
13791398
end
13801399

1381-
end
1400+
struct ResidualObjective{R}
1401+
residual::R
1402+
end
1403+
1404+
(r::ResidualObjective)(u, p) = sum(abs2, r.residual(u, p))
1405+
1406+
end
1407+

lib/OptimizationSciPy/test/runtests.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using OptimizationSciPy, Optimization, Zygote, ReverseDiff, ForwardDiff
22
using Test, Random
3-
using Optimization.SciMLBase: ReturnCode
3+
using Optimization.SciMLBase: ReturnCode, NonlinearLeastSquaresProblem
44
using PythonCall
55

66
function rosenbrock(x, p)
@@ -166,13 +166,12 @@ end
166166
@testset "ScipyLeastSquares" begin
167167
xdata = collect(0:0.1:1)
168168
ydata = 2.0 * xdata .+ 1.0 .+ 0.1 * randn(length(xdata))
169-
function residuals(params, p)
169+
function residuals(params, p=nothing)
170170
a, b = params
171171
return ydata .- (a .* xdata .+ b)
172172
end
173173
x0_ls = [1.0, 0.0]
174-
optf = OptimizationFunction(residuals)
175-
prob = OptimizationProblem(optf, x0_ls)
174+
prob = NonlinearLeastSquaresProblem(residuals, x0_ls)
176175
sol = solve(prob, ScipyLeastSquaresTRF())
177176
@test sol.retcode == ReturnCode.Success
178177
@test sol.u[1] 2.0 atol=0.5
@@ -183,7 +182,7 @@ end
183182
sol = solve(prob, ScipyLeastSquaresLM())
184183
@test sol.retcode == ReturnCode.Success
185184
@test sol.u[1] 2.0 atol=0.5
186-
prob_bounded = OptimizationProblem(optf, x0_ls, nothing, lb = [0.0, -2.0], ub = [5.0, 3.0])
185+
prob_bounded = NonlinearLeastSquaresProblem(residuals, x0_ls; lb = [0.0, -2.0], ub = [5.0, 3.0])
187186
sol = solve(prob_bounded, ScipyLeastSquaresTRF())
188187
@test sol.retcode == ReturnCode.Success
189188
@test 0.0 <= sol.u[1] <= 5.0
@@ -194,12 +193,11 @@ end
194193
end
195194
ydata_outliers = copy(ydata)
196195
ydata_outliers[5] = 10.0
197-
function residuals_outliers(params, p)
196+
function residuals_outliers(params, p=nothing)
198197
a, b = params
199198
return ydata_outliers .- (a .* xdata .+ b)
200199
end
201-
optf_outliers = OptimizationFunction(residuals_outliers)
202-
prob_outliers = OptimizationProblem(optf_outliers, x0_ls)
200+
prob_outliers = NonlinearLeastSquaresProblem(residuals_outliers, x0_ls)
203201
sol_robust = solve(prob_outliers, ScipyLeastSquares(method="trf", loss="huber"))
204202
@test sol_robust.retcode == ReturnCode.Success
205203
end

0 commit comments

Comments
 (0)