Skip to content

Commit 4653ee5

Browse files
committed
Solver revamp src folder
1 parent 57934d1 commit 4653ee5

File tree

8 files changed

+239
-220
lines changed

8 files changed

+239
-220
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ uuid = "ff4d7338-4cf1-434d-91df-b86cb86fb843"
33
version = "0.1.0"
44

55
[deps]
6-
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
6+
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
7+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
78
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
89

910
[compat]
10-
NLPModels = "0.14"
1111
julia = "^1.3.0"
1212

1313
[extras]
14-
ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1717
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1818

1919
[targets]
20-
test = ["ADNLPModels", "LinearAlgebra", "Logging", "Test"]
20+
test = ["LinearAlgebra", "Logging", "Random", "Test"]

src/SolverCore.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
module SolverCore
22

33
# stdlib
4-
using Printf
4+
using Logging, Printf
5+
using OrderedCollections
56

6-
# our packages
7-
using NLPModels
7+
include("solver.jl")
8+
include("output.jl")
89

910
include("logger.jl")
10-
include("stats.jl")
11+
include("parameters.jl")
12+
include("traits.jl")
13+
14+
include("grid-search-tuning.jl")
1115

1216
end

src/grid-search-tuning.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
export grid_search_tune
2+
3+
# TODO: Issue: For grid_search_tune to work, we need to define `reset!`, but LinearOperators also define reset!
4+
function reset! end
5+
6+
# TODO: Decide success and costs of grid_search_tune below
7+
8+
"""
9+
solver, results = grid_search_tune(SolverType, problems; kwargs...)
10+
11+
Simple tuning of solver `SolverType` by grid search, on `problems`, which should be iterable.
12+
The following keyword arguments are available:
13+
- `success`: A function to be applied on a solver output that returns whether the problem has terminated succesfully. Defaults to `o -> o.status == :first_order`.
14+
- `costs`: A vector of cost functions and penalties. Each element is a tuple of two elements. The first is a function to be applied to the output of the solver, and the second is the cost when the solver fails (see `success` above) or throws an error. Defaults to
15+
```
16+
[
17+
(o -> o.elapsed_time, 100.0),
18+
(o -> o.counters.neval_obj + o.counters.neval_cons, 1000),
19+
(o -> !success(o), 1),
20+
]
21+
```
22+
which represent the total elapsed_time (with a penalty of 100.0 for failures); the number of objective and constraints functions evaluations (with a penalty of 1000 for failures); and the number of failures.
23+
- `grid_length`: The number of points in the ranges of the grid for continuous points.
24+
- `solver_kwargs`: Arguments to be passed to the solver. Note: use this to set the stopping parameters, but not the other parameters being optimize.
25+
- Any parameters accepted by the `Solver`: a range to be used instead of the default range.
26+
27+
The default ranges are based on the parameters types, and are as follows:
28+
- `:real`: linear range from `:min` to `:max` with `grid_length` points.
29+
- `:log`: logarithmic range from `:min` to `:max` with `grid_length` points. Computed by exp of linear range of `log(:min)` to `log(:max)`.
30+
- `:bool`: either `false` or `true`.
31+
- `:int`: integer range from `:min` to `:max`.
32+
"""
33+
function grid_search_tune(
34+
::Type{Solver},
35+
problems;
36+
success = o -> o.status == :first_order,
37+
costs = [(o -> o.elapsed_time, 100.0), (o -> !success(o), 1)],
38+
grid_length = 10,
39+
solver_kwargs = Dict(),
40+
kwargs...,
41+
) where {Solver <: AbstractSolver}
42+
solver_params = parameters(Solver)
43+
params = OrderedDict()
44+
for (k, v) in pairs(solver_params)
45+
if v[:type] <: AbstractFloat && (!haskey(v, :scale) || v[:scale] == :linear)
46+
params[k] = LinRange(v[:min], v[:max], grid_length)
47+
elseif v[:type] <: AbstractFloat && v[:scale] == :log
48+
params[k] = exp.(LinRange(log(v[:min]), log(v[:max]), grid_length))
49+
elseif v[:type] == Bool
50+
params[k] = (false, true)
51+
elseif v[:type] <: Integer
52+
params[k] = v[:min]:v[:max]
53+
end
54+
end
55+
for (k, v) in kwargs
56+
params[k] = v
57+
end
58+
59+
# Precompiling
60+
problem = first(problems)
61+
try
62+
solver = Solver(problem)
63+
output = with_logger(NullLogger()) do
64+
solve!(solver, problem)
65+
end
66+
finally
67+
finalize(problem)
68+
end
69+
70+
cost(θ) = begin
71+
total_cost = [zero(x[2]) for x in costs]
72+
for problem in problems
73+
reset!(problem)
74+
try
75+
solver = Solver(problem)
76+
P = (k => θi for (k, θi) in zip(keys(solver_params), θ))
77+
output = with_logger(NullLogger()) do
78+
solve!(solver, problem; P...)
79+
end
80+
for (i, c) in enumerate(costs)
81+
if success(output)
82+
total_cost[i] += (c[1])(output)
83+
else
84+
total_cost[i] += c[2]
85+
end
86+
end
87+
catch ex
88+
for (i, c) in enumerate(costs)
89+
total_cost[i] += c[2]
90+
end
91+
@error ex
92+
finally
93+
finalize(problem)
94+
end
95+
end
96+
total_cost
97+
end
98+
99+
=> cost(θ) for θ in Iterators.product(values(params)...)]
100+
end

src/output.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
export AbstractSolverOutput
2+
3+
# TODO: Define the required fields and API for all Outputs
4+
"""
5+
AbstractSolverOutput{T}
6+
7+
Base type for output of JSO-compliant solvers.
8+
An output must have at least the following:
9+
- `status :: Symbol`
10+
- `solution`
11+
"""
12+
abstract type AbstractSolverOutput{T} end
13+
14+
# TODO: Decision: Should STATUSES be fixed? Should it be all here?
15+
const STATUSES = Dict(
16+
:exception => "unhandled exception",
17+
:first_order => "first-order stationary",
18+
:acceptable => "solved to within acceptable tolerances",
19+
:infeasible => "problem may be infeasible",
20+
:max_eval => "maximum number of function evaluations",
21+
:max_iter => "maximum iteration",
22+
:max_time => "maximum elapsed time",
23+
:neg_pred => "negative predicted reduction",
24+
:not_desc => "not a descent direction",
25+
:small_residual => "small residual",
26+
:small_step => "step too small",
27+
:stalled => "stalled",
28+
:unbounded => "objective function may be unbounded from below",
29+
:unknown => "unknown",
30+
:user => "user-requested stop",
31+
)
32+
33+
function Base.show(io::IO, output::AbstractSolverOutput)
34+
println(io, "Solver output of type $(typeof(output))")
35+
println(io, "Status: $(STATUSES[output.status])")
36+
end

src/parameters.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
export parameters, are_valid_parameters
2+
3+
"""
4+
named_tuple = parameters(solver)
5+
named_tuple = parameters(SolverType)
6+
named_tuple = parameters(SolverType{T})
7+
8+
Return the parameters of a `solver`, or of the type `SolverType`.
9+
You can specify the type `T` of the `SolverType`.
10+
The returned structure is a nested NamedTuple.
11+
Each key of `named_tuple` is the name of a parameter, and its value is a NamedTuple containing
12+
- `default`: The default value of the parameter.
13+
- `type`: The type of the parameter, such as `Int`, `Float64`, `T`, etc.
14+
15+
and possibly other values depending on the `type`.
16+
Some possibilies are:
17+
18+
- `scale`: How to explore the domain
19+
- `:linear`: A continuous value within a range
20+
- `:log`: A positive continuous value that should be explored logarithmically (like 10⁻², 10⁻¹, 1, 10).
21+
- `min`: Minimum value.
22+
- `max`: Maximum value.
23+
24+
Solvers should define
25+
26+
SolverCore.parameters(::Type{Solver{T}}) where T
27+
"""
28+
function parameters end
29+
30+
parameters(::Type{S}) where {S <: AbstractSolver} = parameters(S{Float64})
31+
parameters(::S) where {S <: AbstractSolver} = parameters(S)
32+
33+
"""
34+
are_valid_parameters(solver, args...)
35+
36+
Return whether the parameters given in `args` are valid for `solver`.
37+
The order of the parameters must be the same as in `parameters(solver)`.
38+
39+
Solvers should define
40+
41+
SolverCore.are_valid_parameters(::Type{Solver{T}}, arg1, arg2, ...) where T
42+
"""
43+
function are_valid_parameters end
44+
are_valid_parameters(::Type{S}, args...) where {S <: AbstractSolver} =
45+
are_valid_parameters(S{Float64}, args...)
46+
are_valid_parameters(::S, args...) where {S <: AbstractSolver} = are_valid_parameters(S, args...)

src/solver.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
export AbstractSolver, solve!
2+
3+
# TODO: Define the required fields and API for all Solvers
4+
"""
5+
AbstractSolver
6+
7+
Base type for JSO-compliant solvers.
8+
A solver must have three members:
9+
- `initialized :: Bool`, indicating whether the solver was initialized
10+
- `params :: Dict`, a dictionary of parameters for the solver
11+
- `workspace`, a named tuple with arrays used by the solver.
12+
"""
13+
abstract type AbstractSolver{T} end
14+
15+
function Base.show(io::IO, solver::AbstractSolver)
16+
println(io, "Solver $(typeof(solver))")
17+
end
18+
19+
"""
20+
output = solve!(solver, problem)
21+
22+
Solve `problem` with `solver`.
23+
"""
24+
function solve! end
25+
26+
# TODO: Define general constructors that automatically call `solve!`, etc.?

0 commit comments

Comments
 (0)