Skip to content

Commit 6b0fcbf

Browse files
committed
added OptimizationPyCMA.jl
1 parent ff4b0a8 commit 6b0fcbf

File tree

4 files changed

+187
-0
lines changed

4 files changed

+187
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
matplotlib = ""
3+
cma = ""

lib/OptimizationPyCMA/Project.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name = "OptimizationPyCMA"
2+
uuid = "fb0822aa-1fe5-41d8-99a6-e7bf6c238d3b"
3+
authors = ["Maximilian Pochapski <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
8+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
9+
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
10+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
11+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12+
13+
[compat]
14+
CondaPkg = "0.2.29"
15+
Optimization = "4.4.0"
16+
PythonCall = "0.9.25"
17+
Reexport = "1.2.2"
18+
Test = "1.11.0"
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
module OptimizationPyCMA
2+
3+
using Reexport
4+
@reexport using Optimization
5+
using PythonCall, Optimization.SciMLBase
6+
7+
export PyCMAOpt
8+
9+
struct PyCMAOpt end
10+
11+
# importing PyCMA
12+
const cma = Ref{Py}()
13+
function get_cma()
14+
if !isassigned(cma) || cma[] === nothing
15+
cma[] = pyimport("cma")
16+
end
17+
return cma[]
18+
end
19+
20+
# Defining the SciMLBase interface for PyCMAOpt
21+
22+
SciMLBase.allowsbounds(::PyCMAOpt) = true
23+
SciMLBase.supports_opt_cache_interface(opt::PyCMAOpt) = true
24+
SciMLBase.requiresgradient(::PyCMAOpt) = false
25+
SciMLBase.requireshessian(::PyCMAOpt) = false
26+
SciMLBase.requiresconsjac(::PyCMAOpt) = false
27+
SciMLBase.requiresconshess(::PyCMAOpt) = false
28+
29+
# wrapping Optimization.jl args into a python dict as arguments to PyCMA opts
30+
function __map_optimizer_args(prob::OptimizationCache, opt::PyCMAOpt;
31+
maxiters::Union{Number, Nothing} = nothing,
32+
maxtime::Union{Number, Nothing} = nothing,
33+
abstol::Union{Number, Nothing} = nothing,
34+
reltol::Union{Number, Nothing} = nothing)
35+
if !isnothing(reltol)
36+
@warn "common reltol is currently not used by $(opt)"
37+
end
38+
39+
mapped_args = Dict(
40+
"verbose" => -5,
41+
"bounds" => (prob.lb, prob.ub),
42+
)
43+
44+
if !isnothing(abstol)
45+
mapped_args["tolfun"] = abstol
46+
end
47+
48+
if !isnothing(reltol)
49+
mapped_args["tolfunrel"] = reltol
50+
end
51+
52+
if !isnothing(maxtime)
53+
mapped_args["timeout"] = maxtime
54+
end
55+
56+
if !isnothing(maxiters)
57+
mapped_args["maxiter"] = maxiters
58+
end
59+
60+
return mapped_args
61+
end
62+
63+
function __map_pycma_retcode(stop_dict::Dict{String, Any})
64+
# mapping termination conditions to SciMLBase return codes
65+
if any(k keys(stop_dict) for k in ["ftarget", "tolfun", "tolx"])
66+
return ReturnCode.Success
67+
elseif any(k keys(stop_dict) for k in ["maxiter", "maxfevals"])
68+
return ReturnCode.MaxIters
69+
elseif "timeout" keys(stop_dict)
70+
return ReturnCode.MaxTime
71+
elseif "callback" keys(stop_dict)
72+
return ReturnCode.Terminated
73+
elseif any(k keys(stop_dict) for k in ["tolupsigma", "tolconditioncov", "noeffectcoord", "noeffectaxis", "tolxstagnation", "tolflatfitness", "tolfacupx", "tolstagnation"])
74+
return ReturnCode.Failure
75+
else
76+
return ReturnCode.Default
77+
end
78+
end
79+
80+
function SciMLBase.__solve(cache::OptimizationCache{
81+
F,
82+
RC,
83+
LB,
84+
UB,
85+
LC,
86+
UC,
87+
S,
88+
O,
89+
D,
90+
P,
91+
C
92+
}) where {
93+
F,
94+
RC,
95+
LB,
96+
UB,
97+
LC,
98+
UC,
99+
S,
100+
O <:
101+
PyCMAOpt,
102+
D,
103+
P,
104+
C
105+
}
106+
local x
107+
108+
# doing conversions
109+
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
110+
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
111+
112+
# wrapping the objective function
113+
_loss = function (θ)
114+
x = cache.f(θ, cache.p)
115+
return first(x)
116+
end
117+
118+
# converting the Optimization.jl Args to PyCMA format
119+
opt_args = __map_optimizer_args(cache, cache.opt; cache.solver_args...,
120+
maxiters = maxiters,
121+
maxtime = maxtime)
122+
123+
# init the CMAopt class
124+
es = get_cma().CMAEvolutionStrategy(cache.u0, 1, pydict(opt_args))
125+
logger = es.logger
126+
127+
# running the optimization
128+
t0 = time()
129+
opt_res = es.optimize(_loss)
130+
t1 = time()
131+
132+
# loading logged files from disk
133+
logger.load()
134+
135+
# reading the results
136+
opt_ret_dict = opt_res.stop()
137+
retcode = __map_pycma_retcode(pyconvert(Dict{String, Any}, opt_ret_dict))
138+
139+
# logging and returning results of the optimization
140+
stats = Optimization.OptimizationStats(;
141+
iterations = length(logger.xmean),
142+
time = t1 - t0,
143+
fevals = length(logger.xmean))
144+
145+
SciMLBase.build_solution(cache, cache.opt,
146+
pyconvert(Float64, logger.xrecent[-1][-1]),
147+
pyconvert(Float64, logger.f[-1][-1]); original = opt_res,
148+
retcode = retcode,
149+
stats = stats)
150+
end
151+
152+
end # module OptimizationPyCMA
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using OptimizationPyCMA
2+
using Test
3+
4+
@testset "OptimizationPyCMA.jl" begin
5+
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
6+
x0 = zeros(2)
7+
_p = [1.0, 100.0]
8+
l1 = rosenbrock(x0, _p)
9+
f = OptimizationFunction(rosenbrock)
10+
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
11+
sol = solve(prob, PyCMAOpt())
12+
@test 10 * sol.objective < l1
13+
sol = solve(prob, PyCMAOpt(), maxiters = 100)
14+
end

0 commit comments

Comments
 (0)