Skip to content

Commit f528b3a

Browse files
committed
Update caching in hybrid pso
1 parent 64c8b3f commit f528b3a

File tree

2 files changed

+75
-43
lines changed

2 files changed

+75
-43
lines changed

src/hybrid.jl

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,83 @@
55
result[i] = sol.u
66
end
77

8-
function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem,
9-
opt::HybridPSO{Backend, LocalOpt},
10-
args...;
11-
abstol = nothing,
12-
reltol = nothing,
13-
maxiters = 100,
14-
local_maxiters = 10,
8+
struct HybridPSOCache{TPc, TSp, TAlg}
9+
pso_cache::TPc
10+
start_points::TSp
11+
alg::TAlg
12+
end
13+
14+
function SciMLBase.init(
15+
prob::OptimizationProblem, opt::HybridPSO{Backend, LocalOpt}, args...;
1516
kwargs...) where {Backend, LocalOpt <: Union{LBFGS, BFGS}}
16-
t0 = time()
1717
psoalg = opt.pso
18-
local_opt = opt.local_opt
1918
backend = opt.backend
2019

21-
sol_pso = solve(prob, psoalg, args...; maxiters, kwargs...)
20+
pso_cache = init(prob, psoalg)
21+
22+
start_points = KernelAbstractions.allocate(
23+
backend, typeof(prob.u0), opt.pso.num_particles)
24+
25+
return HybridPSOCache{
26+
typeof(pso_cache), typeof(start_points), typeof(opt)}(pso_cache, start_points, opt)
27+
end
28+
29+
function reinit_cache!(cache::HybridPSOCache,
30+
opt::HybridPSO{Backend, LocalOpt}) where {Backend, LocalOpt <: Union{LBFGS, BFGS}}
31+
reinit!(cache.pso_cache)
32+
fill!(cache.start_points, zero(eltype(cache.start_points)))
33+
# prob = cache.prob
34+
# backend = opt.backend
35+
# particles = cache.particles
36+
37+
# kernel! = PSOGPU.gpu_init_particles!(backend)
38+
# kernel!(particles, prob, opt, typeof(prob.u0); ndrange = opt.num_particles)
39+
40+
# best_particle = minimum(particles)
41+
# _init_gbest = SPSOGBest(best_particle.best_position, best_particle.best_cost)
42+
43+
# copyto!(cache.gbest, [_init_gbest])
44+
45+
return nothing
46+
end
47+
48+
function Base.getproperty(cache::HybridPSOCache, name::Symbol)
49+
if name (:start_points, :pso_cache, :alg)
50+
return getfield(cache, name)
51+
else
52+
return getproperty(cache.pso_cache, name)
53+
end
54+
end
55+
56+
function Base.setproperty!(cache::HybridPSOCache, name::Symbol, val)
57+
if name (:start_points, :pso_cache, :alg)
58+
return setfield!(cache, name, val)
59+
else
60+
return setproperty!(cache.pso_cache, name, val)
61+
end
62+
end
63+
64+
function SciMLBase.solve!(
65+
cache::HybridPSOCache, opt::HybridPSO{Backend, LocalOpt}, args...;
66+
abstol = nothing,
67+
reltol = nothing,
68+
maxiters = 100, local_maxiters = 10, kwargs...) where {
69+
Backend, LocalOpt <: Union{LBFGS, BFGS}}
70+
71+
pso_cache = cache.pso_cache
72+
73+
sol_pso = solve!(pso_cache)
2274
x0s = sol_pso.original
23-
prob = remake(prob, lb = nothing, ub = nothing)
75+
76+
backend = opt.backend
77+
78+
prob = remake(cache.prob, lb = nothing, ub = nothing)
2479
f = Base.Fix2(prob.f.f, prob.p)
2580
∇f = instantiate_gradient(f, prob.f.adtype)
2681

2782
kernel = simplebfgs_run!(backend)
28-
result = KernelAbstractions.allocate(backend, typeof(prob.u0), length(x0s))
83+
result = cache.start_points
84+
copyto!(result, x0s)
2985
nlprob = NonlinearProblem{false}(∇f, prob.u0)
3086

3187
nlalg = LocalOpt isa LBFGS ?

src/solve.jl

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,6 @@ function SciMLBase.init(
3535
prob, opt, particles, init_gbest)
3636
end
3737

38-
function SciMLBase.init(
39-
prob::OptimizationProblem, opt::ParallelPSOKernel, args...; kwargs...)
40-
backend = opt.backend
41-
@assert prob.u0 isa SArray
42-
43-
## Bounds check
44-
lb, ub = check_init_bounds(prob)
45-
lb, ub = check_init_bounds(prob)
46-
prob = remake(prob; lb = lb, ub = ub)
47-
48-
particles = KernelAbstractions.allocate(
49-
backend, SPSOParticle{typeof(prob.u0), eltype(typeof(prob.u0))}, opt.num_particles)
50-
kernel! = gpu_init_particles!(backend)
51-
52-
kernel!(particles, prob, opt, typeof(prob.u0); ndrange = opt.num_particles)
53-
54-
best_particle = minimum(particles)
55-
_init_gbest = SPSOGBest(best_particle.best_position, best_particle.best_cost)
56-
57-
init_gbest = KernelAbstractions.allocate(backend, typeof(_init_gbest), (1,))
58-
copyto!(init_gbest, [_init_gbest])
59-
return PSOCache{
60-
typeof(prob), typeof(opt), typeof(particles), typeof(init_gbest)}(
61-
prob, opt, particles, init_gbest)
62-
end
63-
6438
function SciMLBase.init(
6539
prob::OptimizationProblem, opt::ParallelSyncPSOKernel, args...; kwargs...)
6640
backend = opt.backend
@@ -85,7 +59,7 @@ function SciMLBase.init(
8559
prob, opt, particles, init_gbest)
8660
end
8761

88-
function SciMLBase.reinit!(cache::PSOCache; kwargs...)
62+
function SciMLBase.reinit!(cache::Union{PSOCache, HybridPSOCache}; kwargs...)
8963
reinit_cache!(cache, cache.alg)
9064
end
9165

@@ -122,7 +96,8 @@ function reinit_cache!(cache::PSOCache, opt::ParallelSyncPSOKernel)
12296
return nothing
12397
end
12498

125-
function SciMLBase.solve!(cache, args...; maxiters = 100, kwargs...)
99+
function SciMLBase.solve!(
100+
cache::Union{PSOCache, HybridPSOCache}, args...; maxiters = 100, kwargs...)
126101
solve!(cache, cache.alg, args...; maxiters, kwargs...)
127102
end
128103

@@ -150,8 +125,8 @@ function SciMLBase.solve!(
150125
prob = cache.prob
151126
t0 = time()
152127
gbest, particles = vectorized_solve!(prob,
153-
init_gbest,
154-
gpu_particles,
128+
cache.gbest,
129+
cache.particles,
155130
opt,
156131
args...;
157132
kwargs...)
@@ -163,7 +138,8 @@ function SciMLBase.solve!(
163138
stats = Optimization.OptimizationStats(; time = t1 - t0))
164139
end
165140

166-
function SciMLBase.solve(prob::OptimizationProblem, opt::ParallelPSOKernel,
141+
function SciMLBase.solve(prob::OptimizationProblem,
142+
opt::Union{ParallelPSOKernel, ParallelSyncPSOKernel, HybridPSO},
167143
args...; maxiters = 100, kwargs...)
168144
solve!(init(prob, opt, args...; maxiters, kwargs...), opt)
169145
end

0 commit comments

Comments
 (0)