Commit 6349e73

Merge pull request #8 from JuliaAI/dev
Add AdaptiveParticleSwarm strategy
2 parents 099441d + 5020e2e

File tree: 12 files changed, +443 −63 lines

src/MLJParticleSwarmOptimization.jl

Lines changed: 3 additions & 1 deletion

@@ -6,11 +6,13 @@ using Distributions
 using MLJBase
 using MLJTuning
 
-export StaticCoeffs, ParticleSwarm
+export ParticleSwarm, AdaptiveParticleSwarm
 
 include("swarm.jl")
 include("parameters.jl")
 include("update.jl")
+include("strategies/abstract.jl")
 include("strategies/basic.jl")
+include("strategies/adaptive.jl")
 
 end

src/parameters.jl

Lines changed: 5 additions & 10 deletions

@@ -4,17 +4,13 @@
 
 # Initialize particle swarm state
 
-function initialize(
-    r::Union{ParamRange, Tuple{ParamRange, Any}},
-    tuning::AbstractParticleSwarm
-)
-    return initialize([r], tuning)
+function initialize(rng::AbstractRNG, r::Union{ParamRange, Tuple{ParamRange, Any}}, n::Int)
+    return initialize(rng, [r], n)
 end
 
-function initialize(rs::AbstractVector, tuning::AbstractParticleSwarm)
-    n = tuning.n_particles
+function initialize(rng::AbstractRNG, rs::AbstractVector, n::Int)
     # `length` is 1 for `NumericRange` and the number of categories for `NominalRange`
-    ranges, parameters, lengths, Xᵢ = zip(_initialize.(tuning.rng, rs, n)...)
+    ranges, parameters, lengths, Xᵢ = zip(_initialize.(rng, rs, n)...)
     indices = _to_indices(lengths)
     X = hcat(Xᵢ...)
     V = zero(X)
@@ -102,9 +98,8 @@ end
 # For updating `state.parameters`, the model hyperparameters to be returned, from their
 # internal representation `state.X`
 
-function retrieve!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
+function retrieve!(rng::AbstractRNG, state::ParticleSwarmState)
     ranges, params, indices, X = state.ranges, state.parameters, state.indices, state.X
-    rng = tuning.rng
     for (r, p, i) in zip(ranges, params, indices)
         _retrieve!(rng, p, r, view(X, :, i))
     end

src/strategies/abstract.jl

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+function initialize(r, tuning::AbstractParticleSwarm)
+    return initialize(tuning.rng, r, tuning.n_particles)
+end
+
+function retrieve!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
+    return retrieve!(tuning.rng, state)
+end
+
+function pbest!(state::ParticleSwarmState, measurements, tuning::AbstractParticleSwarm)
+    return pbest!(state, measurements, tuning.prob_shift)
+end

src/strategies/adaptive.jl

Lines changed: 227 additions & 0 deletions

@@ -0,0 +1,227 @@
+"""
+    AdaptiveParticleSwarm(n_particles = 3,
+                          c1 = 2.0,
+                          c2 = 2.0,
+                          prob_shift = 0.25,
+                          rng = Random.GLOBAL_RNG)
+
+Instantiate an adaptive particle swarm optimization tuning strategy. A swarm is
+initiated by sampling hyperparameters with their customizable priors, and new
+models are generated by referencing each member's and the swarm's best models so
+far.
+
+### Supported Ranges and Discrete Hyperparameter Handling
+
+See [`ParticleSwarm`](@ref) for more information about supported ranges and how
+discrete hyperparameters are handled.
+
+### Algorithm
+
+Hyperparameter ranges are sampled and concatenated into position vectors for
+each swarm particle. Velocity is initialized to zero, and in each iteration,
+every particle's position is updated to approach its personal best and the
+swarm's best models so far with the equations:
+
+\$vₖ₊₁ = w⋅vₖ + c₁⋅rand()⋅(pbest - xₖ) + c₂⋅rand()⋅(gbest - xₖ)\$
+
+\$xₖ₊₁ = xₖ + vₖ₊₁\$
+
+Coefficients `w`, `c1`, `c2` are adaptively adjusted at each iteration by
+determining the evolutionary phase of the swarm. The evolutionary factor is
+calculated by comparing each particle's mean distance to the other members of
+the swarm. This factor is then used to classify whether the swarm is in an
+exploration, exploitation, convergence, or jumping-out phase, and the tuning
+hyperparameters are calibrated accordingly. For more information, refer to
+"Adaptive Particle Swarm Optimization" by Zhan, Zhang, Li, and Chung. Note that
+we omit the elitist learning strategy in the paper.
+
+New models are then generated for evaluation by mutating the fields of a deep
+copy of `model`. If the corresponding range has a specified `scale` function,
+then the transformation is applied before the hyperparameter is returned. If
+`scale` is a symbol (eg, `:log`), it is ignored.
+"""
+mutable struct AdaptiveParticleSwarm{R<:AbstractRNG} <: AbstractParticleSwarm
+    n_particles::Int
+    c1::Float64
+    c2::Float64
+    prob_shift::Float64
+    rng::R
+end
+
+# Constructor
+
+function AdaptiveParticleSwarm(;
+    n_particles=3,
+    c1=2.0,
+    c2=2.0,
+    prob_shift=0.25,
+    rng::R=Random.GLOBAL_RNG
+) where {R}
+    swarm = AdaptiveParticleSwarm{R}(n_particles, c1, c2, prob_shift, rng)
+    message = MLJTuning.clean!(swarm)
+    isempty(message) || @warn message
+    return swarm
+end
+
+# Validate tuning hyperparameters
+
+function MLJTuning.clean!(tuning::AdaptiveParticleSwarm)
+    warning = ""
+    if tuning.n_particles < 3
+        warning *= "AdaptiveParticleSwarm requires at least 3 particles. " *
+                   "Resetting n_particles=3. "
+        tuning.n_particles = 3
+    end
+    c1, c2 = tuning.c1, tuning.c2
+    if !(1.5 ≤ c1 ≤ 2.5) || !(1.5 ≤ c2 ≤ 2.5) || (c1 + c2 > 4)
+        c1, c2 = _clamp_coefficients(c1, c2)
+        warning *= "AdaptiveParticleSwarm requires 1.5 ≤ c1 ≤ 2.5, 1.5 ≤ c2 ≤ 2.5, and " *
+                   "c1 + c2 ≤ 4. Resetting coefficients c1=$(c1), c2=$(c2). "
+        tuning.c1 = c1
+        tuning.c2 = c2
+    end
+    if !(0 ≤ tuning.prob_shift < 1)
+        warning *= "AdaptiveParticleSwarm requires 0 ≤ prob_shift < 1. " *
+                   "Resetting prob_shift=0.25. "
+        tuning.prob_shift = 0.25
+    end
+    return warning
+end
+
+# Helper function to clamp swarm coefficients to the interval [1.5, 2.5] with a sum of at
+# most 4
+
+function _clamp_coefficients(c1, c2)
+    c1 = min(max(c1, 1.5), 2.5)
+    c2 = min(max(c2, 1.5), 2.5)
+    scale = 4. / (c1 + c2)
+    if scale < 1
+        c1 *= scale
+        c2 *= scale
+    end
+    return c1, c2
+end
+
+# Initial state
+
+function MLJTuning.setup(tuning::AdaptiveParticleSwarm, model, ranges, n, verbosity)
+    # state, evolutionary phase, swarm coefficients
+    return (initialize(ranges, tuning), nothing, tuning.c1, tuning.c2)
+end
+
+# New models
+
+function MLJTuning.models(
+    tuning::AdaptiveParticleSwarm,
+    model,
+    history,
+    (state, phase, c1, c2),
+    n_remaining,
+    verbosity
+)
+    n_particles = tuning.n_particles
+    if !isnothing(history)
+        sig = MLJTuning.signature(history[1].measure[1])
+        measurements = similar(state.pbest)
+        map(history[end-n_particles+1:end]) do h
+            measurements[h.metadata] = sig * h.measurement[1]
+        end
+        pbest!(state, measurements, tuning)
+        gbest!(state)
+        f = _evolutionary_factor(state.X, argmin(state.pbest))
+        phase = _evolutionary_phase(f, phase)
+        w, c1, c2 = _adapt_parameters(tuning.rng, c1, c2, f, phase)
+        move!(tuning.rng, state, w, c1, c2)
+    end
+    retrieve!(state, tuning)
+    fields = getproperty.(state.ranges, :field)
+    new_models = map(1:n_particles) do i
+        clone = deepcopy(model)
+        for (field, param) in zip(fields, getindex.(state.parameters, i))
+            recursive_setproperty!(clone, field, param)
+        end
+        (clone, i)
+    end
+    return new_models, (state, phase, c1, c2)
+end
+
+# Helper functions to calculate the evolutionary factor and phase
+
+function _evolutionary_factor(X, gbest_i)
+    n_particles = size(X, 1)
+    dists = zeros(n_particles, n_particles)
+    for i in 1:n_particles
+        for j in i+1:n_particles
+            dists[j, i] = dists[i, j] = norm(X[i, :] - X[j, :])
+        end
+    end
+    mean_dists = sum(dists, dims=2) / (n_particles - 1)
+    min_dist, max_dist = extrema(mean_dists)
+    gbest_dist = mean_dists[gbest_i]
+    f = (gbest_dist - min_dist) / max(max_dist - min_dist, sqrt(eps()))
+    return f
+end
+
+function _evolutionary_phase(f, phase)
+    # Classify evolutionary phase
+    μs = [μ₁(f), μ₂(f), μ₃(f), μ₄(f)]
+    if phase === nothing  # first iteration
+        phase = argmax(μs)
+    else
+        next_phase = mod1(phase + 1, 4)
+        # switch to the next phase if possible
+        if μs[next_phase] > 0
+            phase = next_phase
+        # stay in the current phase if possible, else pick the most likely phase
+        elseif μs[phase] == 0
+            phase = argmax(μs)
+        end
+    end
+    return phase
+end
+
+# Helper functions to calculate probabilities of the four evolutionary states
+
+μ₁(f) = f ≤ 0.4 ? 0.0 :
+        f ≤ 0.6 ? 5 * f - 2 :
+        f ≤ 0.7 ? 1.0 :
+        f ≤ 0.8 ? -10 * f + 8 :
+        0.0
+
+μ₂(f) = f ≤ 0.2 ? 0.0 :
+        f ≤ 0.3 ? 10 * f - 2 :
+        f ≤ 0.4 ? 1.0 :
+        f ≤ 0.6 ? -5 * f + 3 :
+        0.0
+
+μ₃(f) = f ≤ 0.1 ? 1.0 :
+        f ≤ 0.3 ? -5 * f + 1.5 :
+        0.0
+
+μ₄(f) = f ≤ 0.7 ? 0.0 :
+        f ≤ 0.9 ? 5 * f - 3.5 :
+        1.0
+
+# Adaptive control of the swarm's parameters
+
+function _adapt_parameters(rng, c1, c2, f, phase)
+    w = 1.0 / (1.0 + 1.5 * exp(-2.6 * f))  # update inertia
+    δ = rand(rng) * 0.05 + 0.05  # coefficient acceleration
+    if phase === 1  # exploration
+        c1 += δ
+        c2 -= δ
+    elseif phase === 2  # exploitation
+        δ *= 0.5
+        c1 += δ
+        c2 -= δ
+    elseif phase === 3  # convergence
+        δ *= 0.5
+        c1 += δ
+        c2 += δ
+    else  # jumping out
+        c1 -= δ
+        c2 += δ
+    end
+    c1, c2 = _clamp_coefficients(c1, c2)
+    return w, c1, c2
+end
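
To make the evolutionary-factor and phase logic above concrete, here is a small
standalone sketch (plain Julia mirroring `_evolutionary_factor` and the μ
membership functions; the toy positions `X` and the choice of `gbest_i` are
illustrative, not part of the commit):

    using LinearAlgebra  # for norm

    X = [0.0 0.0; 1.0 0.0; 0.0 1.0]  # 3 particles × 2 dims
    n = size(X, 1)
    dists = [norm(X[i, :] - X[j, :]) for i in 1:n, j in 1:n]
    mean_dists = vec(sum(dists, dims=2)) / (n - 1)

    gbest_i = 1  # pretend particle 1 holds the best pbest
    dmin, dmax = extrema(mean_dists)
    f = (mean_dists[gbest_i] - dmin) / max(dmax - dmin, sqrt(eps()))  # 0.0 here

    # with f == 0, μ₃ (convergence) is the only nonzero membership, so the
    # first iteration would classify the swarm as phase 3
    μ₃(f) = f ≤ 0.1 ? 1.0 : f ≤ 0.3 ? -5 * f + 1.5 : 0.0
    @assert μ₃(f) == 1.0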

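As a usage sketch (not part of this commit), the new strategy plugs into MLJ's
`TunedModel` in the same way as the existing `ParticleSwarm`; the
`DecisionTreeClassifier` model, its `max_depth` range, and the iris data are
illustrative assumptions:

    using MLJ, MLJParticleSwarmOptimization, StableRNGs

    Tree = @load DecisionTreeClassifier pkg=DecisionTree
    tree = Tree()
    r = range(tree, :max_depth; lower=1, upper=10)

    aps = AdaptiveParticleSwarm(n_particles=5, rng=StableRNG(1234))
    self_tuning = TunedModel(model=tree, tuning=aps, range=r,
                             measure=log_loss, resampling=CV(nfolds=3), n=25)

    X, y = @load_iris
    mach = machine(self_tuning, X, y)
    fit!(mach, verbosity=0)
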
src/strategies/basic.jl

Lines changed: 4 additions & 4 deletions

@@ -10,7 +10,7 @@ Instantiate a particle swarm optimization tuning strategy. A swarm is initiated
 by sampling hyperparameters with their customizable priors, and new models are
 generated by referencing each member's and the swarm's best models so far.
 
-### Supported ranges
+### Supported Ranges
 
 A single one-dimensional range or vector of one-dimensional ranges can be
 specified. `ParamRange` objects are constructed using the `range` method. If not
@@ -171,9 +171,9 @@ function MLJTuning.models(
         map(history[end-n_particles+1:end]) do h
             measurements[h.metadata] = sign * h.measurement[1]
         end
-        pbest!(state, tuning, measurements)
-        gbest!(state, tuning)
-        move!(state, tuning)
+        pbest!(state, measurements, tuning)
+        gbest!(state)
+        move!(tuning.rng, state, T(tuning.w), T(tuning.c1), T(tuning.c2))
     end
     retrieve!(state, tuning)
     fields = getproperty.(state.ranges, :field)
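
The `ParamRange` objects mentioned in the Supported Ranges section are built
with MLJ's `range` method. A brief sketch (the model and field names are
assumptions for illustration; pairing a numeric range with a `Distributions`
prior follows the pattern used in the tests below):

    using MLJ, Distributions

    Tree = @load DecisionTreeClassifier pkg=DecisionTree
    tree = Tree()
    r1 = range(tree, :max_depth; lower=1, upper=10)      # NumericRange
    r2 = range(tree, :post_prune; values=[true, false])  # NominalRange
    ranges = [(r1, Uniform), r2]  # optional prior attached to the numeric range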

src/update.jl

Lines changed: 4 additions & 6 deletions

@@ -1,8 +1,7 @@
 # Move the swarm
 
-function move!(state::ParticleSwarmState{T}, tuning::AbstractParticleSwarm) where {T}
-    rng, X, V = tuning.rng, state.X, state.V
-    w, c1, c2 = T(tuning.w), T(tuning.c1), T(tuning.c2)
+function move!(rng::AbstractRNG, state::ParticleSwarmState{T}, w, c1, c2) where {T}
+    X, V = state.X, state.V
     @. V = w*V + c1*rand(rng, T)*(state.pbest_X - X) + c2*rand(rng, T)*(state.gbest_X - X)
     X .+= V
     for (r, idx) in zip(state.ranges, state.indices)
@@ -24,9 +23,8 @@ end
 
 # Update pbest
 
-function pbest!(state::ParticleSwarmState, tuning::AbstractParticleSwarm, measurements)
+function pbest!(state::ParticleSwarmState, measurements, prob_shift)
     X, pbest, pbest_X = state.X, state.pbest, state.pbest_X
-    prob_shift = tuning.prob_shift
     improved = measurements .<= pbest
     pbest[improved] .= measurements[improved]
     for (r, p, i) in zip(state.ranges, state.parameters, state.indices)
@@ -48,7 +46,7 @@ end
 
 # Update gbest
 
-function gbest!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
+function gbest!(state::ParticleSwarmState)
     pbest, pbest_X, gbest, gbest_X = state.pbest, state.pbest_X, state.gbest, state.gbest_X
     best, i = findmin(pbest)
     gbest .= best
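
For intuition, a minimal standalone sketch of the update performed by `move!`,
with plain arrays standing in for the `ParticleSwarmState` fields and toy
values for the coefficients:

    using Random

    rng = MersenneTwister(0)
    X = [0.2 0.8; 0.6 0.4]             # positions: 2 particles × 2 dims
    V = zero(X)                        # velocities start at zero
    pbest_X = copy(X)                  # personal bests
    gbest_X = repeat(X[1:1, :], 2, 1)  # pretend particle 1 is the global best
    w, c1, c2 = 0.9, 2.0, 2.0

    # same rule as `move!`: inertia plus random pulls toward pbest and gbest
    @. V = w*V + c1*rand(rng, Float64)*(pbest_X - X) + c2*rand(rng, Float64)*(gbest_X - X)
    X .+= V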

test/parameters.jl

Lines changed: 7 additions & 7 deletions

@@ -69,29 +69,29 @@
 end
 
 @testset "Initialize one range" begin
-    ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
+    rng = StableRNG(1234)
     for (r, l, i, X) in zip(rs, lengths, indices, Xs)
-        state = PSO.initialize(r, ps)
+        state = PSO.initialize(rng, r, n)
         @test state.ranges == (r,)
         @test state.indices == (l == 1 ? 1 : 1:l,)
         @test state.X ≈ X
     end
 end
 
 @testset "Initialize multiple ranges" begin
-    ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
+    rng = StableRNG(1234)
     ranges = [r1, (r2, Uniform), (r3, d3), r4]
-    state = PSO.initialize(ranges, ps)
+    state = PSO.initialize(rng, ranges, n)
     @test state.ranges == rs
     @test state.indices == indices
     @test state.X ≈ hcat(Xs...)
 end
 
 @testset "Retrieve parameters" begin
-    ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
+    rng = StableRNG(1234)
     ranges = [r1, (r2, Uniform), (r3, d3), r4]
-    state = PSO.initialize(ranges, ps)
-    PSO.retrieve!(state, ps)
+    state = PSO.initialize(rng, ranges, n)
+    PSO.retrieve!(rng, state)
     @test state.parameters == (
         ["a", "a", "c"],
         [553, 250, 375],
