|
5 | 5 | result[i] = sol.u |
6 | 6 | end |
7 | 7 |
|
8 | | -function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, |
9 | | - opt::HybridPSO{Backend, LocalOpt}, |
10 | | - args...; |
11 | | - abstol = nothing, |
12 | | - reltol = nothing, |
13 | | - maxiters = 100, |
14 | | - local_maxiters = 10, |
| 8 | +struct HybridPSOCache{TPc, TSp, TAlg} |
| 9 | + pso_cache::TPc |
| 10 | + start_points::TSp |
| 11 | + alg::TAlg |
| 12 | +end |
| 13 | + |
| 14 | +function SciMLBase.init( |
| 15 | + prob::OptimizationProblem, opt::HybridPSO{Backend, LocalOpt}, args...; |
15 | 16 | kwargs...) where {Backend, LocalOpt <: Union{LBFGS, BFGS}} |
16 | | - t0 = time() |
17 | 17 | psoalg = opt.pso |
18 | | - local_opt = opt.local_opt |
19 | 18 | backend = opt.backend |
20 | 19 |
|
21 | | - sol_pso = solve(prob, psoalg, args...; maxiters, kwargs...) |
| 20 | + pso_cache = init(prob, psoalg) |
| 21 | + |
| 22 | + start_points = KernelAbstractions.allocate( |
| 23 | + backend, typeof(prob.u0), opt.pso.num_particles) |
| 24 | + |
| 25 | + return HybridPSOCache{ |
| 26 | + typeof(pso_cache), typeof(start_points), typeof(opt)}(pso_cache, start_points, opt) |
| 27 | +end |
| 28 | + |
| 29 | +function reinit_cache!(cache::HybridPSOCache, |
| 30 | + opt::HybridPSO{Backend, LocalOpt}) where {Backend, LocalOpt <: Union{LBFGS, BFGS}} |
| 31 | + reinit!(cache.pso_cache) |
| 32 | + fill!(cache.start_points, zero(eltype(cache.start_points))) |
| 33 | + # prob = cache.prob |
| 34 | + # backend = opt.backend |
| 35 | + # particles = cache.particles |
| 36 | + |
| 37 | + # kernel! = PSOGPU.gpu_init_particles!(backend) |
| 38 | + # kernel!(particles, prob, opt, typeof(prob.u0); ndrange = opt.num_particles) |
| 39 | + |
| 40 | + # best_particle = minimum(particles) |
| 41 | + # _init_gbest = SPSOGBest(best_particle.best_position, best_particle.best_cost) |
| 42 | + |
| 43 | + # copyto!(cache.gbest, [_init_gbest]) |
| 44 | + |
| 45 | + return nothing |
| 46 | +end |
| 47 | + |
| 48 | +function Base.getproperty(cache::HybridPSOCache, name::Symbol) |
| 49 | + if name ∈ (:start_points, :pso_cache, :alg) |
| 50 | + return getfield(cache, name) |
| 51 | + else |
| 52 | + return getproperty(cache.pso_cache, name) |
| 53 | + end |
| 54 | +end |
| 55 | + |
| 56 | +function Base.setproperty!(cache::HybridPSOCache, name::Symbol, val) |
| 57 | + if name ∈ (:start_points, :pso_cache, :alg) |
| 58 | + return setfield!(cache, name, val) |
| 59 | + else |
| 60 | + return setproperty!(cache.pso_cache, name, val) |
| 61 | + end |
| 62 | +end |
| 63 | + |
| 64 | +function SciMLBase.solve!( |
| 65 | + cache::HybridPSOCache, opt::HybridPSO{Backend, LocalOpt}, args...; |
| 66 | + abstol = nothing, |
| 67 | + reltol = nothing, |
| 68 | + maxiters = 100, local_maxiters = 10, kwargs...) where { |
| 69 | + Backend, LocalOpt <: Union{LBFGS, BFGS}} |
| 70 | + |
| 71 | + pso_cache = cache.pso_cache |
| 72 | + |
| 73 | + sol_pso = solve!(pso_cache) |
22 | 74 | x0s = sol_pso.original |
23 | | - prob = remake(prob, lb = nothing, ub = nothing) |
| 75 | + |
| 76 | + backend = opt.backend |
| 77 | + |
| 78 | + prob = remake(cache.prob, lb = nothing, ub = nothing) |
24 | 79 | f = Base.Fix2(prob.f.f, prob.p) |
25 | 80 | ∇f = instantiate_gradient(f, prob.f.adtype) |
26 | 81 |
|
27 | 82 | kernel = simplebfgs_run!(backend) |
28 | | - result = KernelAbstractions.allocate(backend, typeof(prob.u0), length(x0s)) |
| 83 | + result = cache.start_points |
| 84 | + copyto!(result, x0s) |
29 | 85 | nlprob = NonlinearProblem{false}(∇f, prob.u0) |
30 | 86 |
|
31 | 87 | nlalg = LocalOpt isa LBFGS ? |
|
0 commit comments