|
| 1 | +export grid_search_tune |
| 2 | + |
| 3 | +# TODO: Issue: For grid_search_tune to work, we need to define `reset!`, but LinearOperators also define reset! |
| 4 | +function reset! end |
| 5 | + |
| 6 | +# TODO: Decide success and costs of grid_search_tune below |
| 7 | + |
| 8 | +""" |
| 9 | + solver, results = grid_search_tune(SolverType, problems; kwargs...) |
| 10 | +
|
| 11 | +Simple tuning of solver `SolverType` by grid search, on `problems`, which should be iterable. |
| 12 | +The following keyword arguments are available: |
| 13 | +- `success`: A function to be applied on a solver output that returns whether the problem has terminated succesfully. Defaults to `o -> o.status == :first_order`. |
| 14 | +- `costs`: A vector of cost functions and penalties. Each element is a tuple of two elements. The first is a function to be applied to the output of the solver, and the second is the cost when the solver fails (see `success` above) or throws an error. Defaults to |
| 15 | +``` |
| 16 | +[ |
| 17 | + (o -> o.elapsed_time, 100.0), |
| 18 | + (o -> o.counters.neval_obj + o.counters.neval_cons, 1000), |
| 19 | + (o -> !success(o), 1), |
| 20 | +] |
| 21 | +``` |
| 22 | +which represent the total elapsed_time (with a penalty of 100.0 for failures); the number of objective and constraints functions evaluations (with a penalty of 1000 for failures); and the number of failures. |
| 23 | +- `grid_length`: The number of points in the ranges of the grid for continuous points. |
| 24 | +- `solver_kwargs`: Arguments to be passed to the solver. Note: use this to set the stopping parameters, but not the other parameters being optimize. |
| 25 | +- Any parameters accepted by the `Solver`: a range to be used instead of the default range. |
| 26 | +
|
| 27 | +The default ranges are based on the parameters types, and are as follows: |
| 28 | +- `:real`: linear range from `:min` to `:max` with `grid_length` points. |
| 29 | +- `:log`: logarithmic range from `:min` to `:max` with `grid_length` points. Computed by exp of linear range of `log(:min)` to `log(:max)`. |
| 30 | +- `:bool`: either `false` or `true`. |
| 31 | +- `:int`: integer range from `:min` to `:max`. |
| 32 | +""" |
| 33 | +function grid_search_tune( |
| 34 | + ::Type{Solver}, |
| 35 | + problems; |
| 36 | + success = o -> o.status == :first_order, |
| 37 | + costs = [(o -> o.elapsed_time, 100.0), (o -> !success(o), 1)], |
| 38 | + grid_length = 10, |
| 39 | + solver_kwargs = Dict(), |
| 40 | + kwargs..., |
| 41 | +) where {Solver <: AbstractSolver} |
| 42 | + solver_params = parameters(Solver) |
| 43 | + params = OrderedDict() |
| 44 | + for (k, v) in pairs(solver_params) |
| 45 | + if v[:type] <: AbstractFloat && (!haskey(v, :scale) || v[:scale] == :linear) |
| 46 | + params[k] = LinRange(v[:min], v[:max], grid_length) |
| 47 | + elseif v[:type] <: AbstractFloat && v[:scale] == :log |
| 48 | + params[k] = exp.(LinRange(log(v[:min]), log(v[:max]), grid_length)) |
| 49 | + elseif v[:type] == Bool |
| 50 | + params[k] = (false, true) |
| 51 | + elseif v[:type] <: Integer |
| 52 | + params[k] = v[:min]:v[:max] |
| 53 | + end |
| 54 | + end |
| 55 | + for (k, v) in kwargs |
| 56 | + params[k] = v |
| 57 | + end |
| 58 | + |
| 59 | + # Precompiling |
| 60 | + problem = first(problems) |
| 61 | + try |
| 62 | + solver = Solver(problem) |
| 63 | + output = with_logger(NullLogger()) do |
| 64 | + solve!(solver, problem) |
| 65 | + end |
| 66 | + finally |
| 67 | + finalize(problem) |
| 68 | + end |
| 69 | + |
| 70 | + cost(θ) = begin |
| 71 | + total_cost = [zero(x[2]) for x in costs] |
| 72 | + for problem in problems |
| 73 | + reset!(problem) |
| 74 | + try |
| 75 | + solver = Solver(problem) |
| 76 | + P = (k => θi for (k, θi) in zip(keys(solver_params), θ)) |
| 77 | + output = with_logger(NullLogger()) do |
| 78 | + solve!(solver, problem; P...) |
| 79 | + end |
| 80 | + for (i, c) in enumerate(costs) |
| 81 | + if success(output) |
| 82 | + total_cost[i] += (c[1])(output) |
| 83 | + else |
| 84 | + total_cost[i] += c[2] |
| 85 | + end |
| 86 | + end |
| 87 | + catch ex |
| 88 | + for (i, c) in enumerate(costs) |
| 89 | + total_cost[i] += c[2] |
| 90 | + end |
| 91 | + @error ex |
| 92 | + finally |
| 93 | + finalize(problem) |
| 94 | + end |
| 95 | + end |
| 96 | + total_cost |
| 97 | + end |
| 98 | + |
| 99 | + [θ => cost(θ) for θ in Iterators.product(values(params)...)] |
| 100 | +end |
0 commit comments