Skip to content

Commit 021fe77

Browse files
committed
Add AbstractParticleSwarm type
1 parent 1e1eabc commit 021fe77

File tree

10 files changed

+214
-226
lines changed

10 files changed

+214
-226
lines changed

src/ParticleSwarmOptimization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ParticleSwarmOptimization
22

3+
using LinearAlgebra
34
using Random
45
using Distributions
56
using MLJBase
@@ -10,6 +11,6 @@ export StaticCoeffs, ParticleSwarm
1011
include("swarm.jl")
1112
include("parameters.jl")
1213
include("update.jl")
13-
include("tuning.jl")
14+
include("strategies/basic.jl")
1415

1516
end

src/parameters.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
# Initialize particle swarm state
66

7-
function initialize(r::Union{ParamRange, Tuple{ParamRange, Any}}, ps::ParticleSwarm)
8-
return initialize([r], ps)
7+
function initialize(
8+
r::Union{ParamRange, Tuple{ParamRange, Any}},
9+
tuning::AbstractParticleSwarm
10+
)
11+
return initialize([r], tuning)
912
end
1013

11-
function initialize(rs::AbstractVector, ps::ParticleSwarm)
12-
n = ps.n_particles
13-
ranges, parameters, lens, Xᵢ = zip(_initialize.(Ref(ps.rng), rs, n)...) # wrapped in Ref for compat with Julia 1.0
14+
function initialize(rs::AbstractVector, tuning::AbstractParticleSwarm)
15+
n = tuning.n_particles
16+
# Wrap rng in Ref for compatibility with Julia <= 1.3
17+
ranges, parameters, lens, Xᵢ = zip(_initialize.(Ref(tuning.rng), rs, n)...)
1418
indices = _to_indices(lens)
1519
X = hcat(Xᵢ...)
1620
V = zero(X)
@@ -95,9 +99,9 @@ end
9599
### Retrieval
96100
###
97101

98-
function retrieve!(state::ParticleSwarmState, ps::ParticleSwarm)
102+
function retrieve!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
99103
ranges, params, indices, X = state.ranges, state.parameters, state.indices, state.X
100-
rng = ps.rng
104+
rng = tuning.rng
101105
for (r, p, i) in zip(ranges, params, indices)
102106
_retrieve!(rng, p, r, view(X, :, i))
103107
end
@@ -107,7 +111,8 @@ end
107111
function _retrieve!(rng, p, r::NominalRange, X)
108112
return p .= getindex.(
109113
Ref(r.values),
110-
rand.(Ref(rng), Categorical.(X[i,:] for i in axes(X, 1))) # wrapped in Ref for compat with Julia 1.0
114+
# Wrap rng in Ref for compatibility with Julia <= 1.3
115+
rand.(Ref(rng), Categorical.(X[i,:] for i in axes(X, 1)))
111116
)
112117
end
113118

src/strategies/basic.jl

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
ParticleSwarm(n_particles = 3,
3+
w = 1.0,
4+
c1 = 2.0,
5+
c2 = 2.0,
6+
prob_shift = 0.25,
7+
rng = Random.GLOBAL_RNG)
8+
9+
Instantiate a particle swarm optimization tuning strategy. A swarm is initiated
10+
by sampling hyperparameters with their customizable priors, and new models are
11+
generated by referencing each member's and the swarm's best models so far.
12+
13+
### Supported ranges
14+
15+
A single one-dimensional range or vector of one-dimensional ranges can be
16+
specified. `ParamRange` objects are constructed using the `range` method. If not
17+
paired with a prior, then one is fitted, as follows:
18+
19+
| Range Types | Default Distribution |
20+
|:----------------------- |:-------------------- |
21+
| `NominalRange` | `Dirichlet` |
22+
| Bounded `NumericRange` | `Uniform` |
23+
| Positive `NumericRange` | `Gamma` |
24+
| Other `NumericRange` | `Normal` |
25+
26+
Specifically, in `ParticleSwarm`, the `range` field of a `TunedModel` instance
27+
can be:
28+
29+
- a single one-dimensional range (`ParamRange` object) `r`
30+
31+
- a pair of the form `(r, d)`, with `r` as above and where `d` is:
32+
33+
- a Dirichlet distribution with the same number of categories as `r.values`
34+
(for `NominalRange` `r`)
35+
36+
- any `Distributions.UnivariateDistribution` *instance* (for `NumericRange`
37+
`r`)
38+
39+
- one of the distribution *types* in the table below, for automatic fitting
40+
using `Distributions.fit(d, r)` to a distribution whose support always
41+
lies between `r.lower` and `r.upper` (for `NumericRange` `r`) or the set
42+
of probability vectors (for `NominalRange` `r`)
43+
44+
- any vector of objects of the above form
45+
46+
| Range Types | Distribution Types |
47+
|:----------------------- |:-------------------------------------------------------------------------------------------- |
48+
| `NominalRange` | `Dirichlet` |
49+
| Bounded `NumericRange` | `Arcsine`, `Uniform`, `Biweight`, `Cosine`, `Epanechnikov`, `SymTriangularDist`, `Triweight` |
50+
| Positive `NumericRange` | `Gamma`, `InverseGaussian`, `Poisson` |
51+
| Any `NumericRange` | `Normal`, `Logistic`, `LogNormal`, `Cauchy`, `Gumbel`, `Laplace` |
52+
53+
### Examples
54+
55+
using Distributions
56+
57+
range1 = range(model, :hyper1, lower=0, upper=1)
58+
59+
range2 = [(range(model, :hyper1, lower=1, upper=10), Arcsine),
60+
range(model, :hyper2, lower=2, upper=Inf, unit=1, origin=3),
61+
(range(model, :hyper2, lower=2, upper=4), Normal(0, 3)),
62+
(range(model, :hyper3, values=[:ball, :tree]), Dirichlet)]
63+
64+
### Algorithm
65+
66+
Hyperparameter ranges are sampled and concatenated into position vectors for
67+
each swarm particle. Velocity is initiated to be zeros, and in each iteration,
68+
every particle's position is updated to approach its personal best and the
69+
swarm's best models so far with the equations:
70+
71+
\$vₖ₊₁ = w⋅vₖ + c₁⋅rand()⋅(pbest - xₖ) + c₂⋅rand()⋅(gbest - xₖ)\$
72+
73+
\$xₖ₊₁ = xₖ + vₖ₊₁\$
74+
75+
New models are then generated for evaluation by mutating the fields of a deep
76+
copy of `model`. If the corresponding range has a specified `scale` function,
77+
then the transformation is applied before the hyperparameter is returned. For
78+
integer `NumericRange`s, the hyperparameter is rounded; and for `NominalRange`s,
79+
the hyperparameter is sampled from the specified values with the probability
80+
weights given by each particle.
81+
82+
Personal and social best models are then updated for the swarm. In order to
83+
replicate both the probability weights and the sampled value for `NominalRange`s
84+
of the best models, the weights of unselected values are shifted to the selected
85+
one by the `prob_shift` factor.
86+
"""
87+
mutable struct ParticleSwarm{R<:AbstractRNG} <: AbstractParticleSwarm
88+
n_particles::Int
89+
w::Float64
90+
c1::Float64
91+
c2::Float64
92+
prob_shift::Float64
93+
rng::R
94+
# TODO: topology
95+
end
96+
97+
function ParticleSwarm(;
98+
n_particles=3,
99+
w=1.0,
100+
c1=2.0,
101+
c2=2.0,
102+
prob_shift=0.25,
103+
rng::R=Random.GLOBAL_RNG
104+
) where {R}
105+
swarm = ParticleSwarm{R}(n_particles, w, c1, c2, prob_shift, rng)
106+
message = MLJTuning.clean!(swarm)
107+
isempty(message) || @warn message
108+
return swarm
109+
end
110+
111+
function MLJTuning.clean!(tuning::ParticleSwarm)
112+
warning = ""
113+
if tuning.n_particles < 3
114+
warning *= "ParticleSwarm requires at least 3 particles. Resetting n_particles=3. "
115+
tuning.n_particles = 3
116+
end
117+
if tuning.w < 0
118+
warning *= "ParticleSwarm requires w ≥ 0. Resetting w=1. "
119+
tuning.w = 1
120+
end
121+
if tuning.c1 < 0
122+
warning *= "ParticleSwarm requires c1 ≥ 0. Resetting c1=2. "
123+
tuning.c1 = 2
124+
end
125+
if tuning.c2 < 0
126+
warning *= "ParticleSwarm requires c2 ≥ 0. Resetting c2=2. "
127+
tuning.c2 = 2
128+
end
129+
if !(0 ≤ tuning.prob_shift < 1)
130+
warning *= "ParticleSwarm requires 0 ≤ prob_shift < 1. Resetting prob_shift=0.25. "
131+
tuning.prob_shift = 0.25
132+
end
133+
return warning
134+
end
135+
136+
function MLJTuning.setup(tuning::ParticleSwarm, model, ranges, n, verbosity)
137+
return initialize(ranges, tuning)
138+
end
139+
140+
function MLJTuning.models(
141+
tuning::ParticleSwarm,
142+
model,
143+
history,
144+
state,
145+
n_remaining,
146+
verbosity
147+
)
148+
n_particles = tuning.n_particles
149+
if !isnothing(history)
150+
sig = MLJTuning.signature(first(history).measure)
151+
pbest!(state, tuning, map(h -> sig * h.measurement[1], last(history, n_particles)))
152+
gbest!(state, tuning)
153+
move!(state, tuning)
154+
end
155+
retrieve!(state, tuning)
156+
fields = getproperty.(state.ranges, :field)
157+
new_models = map(1:n_particles) do i
158+
clone = deepcopy(model)
159+
for (field, param) in zip(fields, getindex.(state.parameters, i))
160+
recursive_setproperty!(clone, field, param)
161+
end
162+
clone
163+
end
164+
return new_models, state
165+
end
166+
167+
function MLJTuning.tuning_report(tuning::ParticleSwarm, history, state)
168+
fields = getproperty.(state.ranges, :field)
169+
scales = MLJBase.scale.(state.ranges)
170+
return (; plotting = MLJTuning.plotting_report(fields, scales, history))
171+
end

src/swarm.jl

Lines changed: 1 addition & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,4 @@
1-
"""
2-
ParticleSwarm(n_particles = 3,
3-
w = 1.0,
4-
c1 = 2.0,
5-
c2 = 2.0,
6-
prob_shift = 0.25,
7-
rng = Random.GLOBAL_RNG)
8-
9-
Instantiate a particle swarm optimization tuning strategy. A swarm is initiated
10-
by sampling hyperparameters with their customizable priors, and new models are
11-
generated by referencing each member's and the swarm's best models so far.
12-
13-
### Supported ranges
14-
15-
A single one-dimensional range or vector of one-dimensional ranges can be
16-
specified. `ParamRange` objects are constructed using the `range` method. If not
17-
paired with a prior, then one is fitted, as follows:
18-
19-
| Range Types | Default Distribution |
20-
|:----------------------- |:-------------------- |
21-
| `NominalRange` | `Dirichlet` |
22-
| Bounded `NumericRange` | `Uniform` |
23-
| Positive `NumericRange` | `Gamma` |
24-
| Other `NumericRange` | `Normal` |
25-
26-
Specifically, in `ParticleSwarm`, the `range` field of a `TunedModel` instance
27-
can be:
28-
29-
- a single one-dimensional range (`ParamRange` object) `r`
30-
31-
- a pair of the form `(r, d)`, with `r` as above and where `d` is:
32-
33-
- a Dirichlet distribution with the same number of categories as `r.values`
34-
(for `NominalRange` `r`)
35-
36-
- any `Distributions.UnivariateDistribution` *instance* (for `NumericRange`
37-
`r`)
38-
39-
- one of the distribution *types* in the table below, for automatic fitting
40-
using `Distributions.fit(d, r)` to a distribution whose support always
41-
lies between `r.lower` and `r.upper` (for `NumericRange` `r`) or the set
42-
of probability vectors (for `NominalRange` `r`)
43-
44-
- any vector of objects of the above form
45-
46-
| Range Types | Distribution Types |
47-
|:----------------------- |:-------------------------------------------------------------------------------------------- |
48-
| `NominalRange` | `Dirichlet` |
49-
| Bounded `NumericRange` | `Arcsine`, `Uniform`, `Biweight`, `Cosine`, `Epanechnikov`, `SymTriangularDist`, `Triweight` |
50-
| Positive `NumericRange` | `Gamma`, `InverseGaussian`, `Poisson` |
51-
| Any `NumericRange` | `Normal`, `Logistic`, `LogNormal`, `Cauchy`, `Gumbel`, `Laplace` |
52-
53-
### Examples
54-
55-
using Distributions
56-
57-
range1 = range(model, :hyper1, lower=0, upper=1)
58-
59-
range2 = [(range(model, :hyper1, lower=1, upper=10), Arcsine),
60-
range(model, :hyper2, lower=2, upper=Inf, unit=1, origin=3),
61-
(range(model, :hyper2, lower=2, upper=4), Normal(0, 3)),
62-
(range(model, :hyper3, values=[:ball, :tree]), Dirichlet)]
63-
64-
### Algorithm
65-
66-
Hyperparameter ranges are sampled and concatenated into position vectors for
67-
each swarm particle. Velocity is initiated to be zeros, and in each iteration,
68-
every particle's position is updated to approach its personal best and the
69-
swarm's best models so far with the equations:
70-
71-
\$vₖ₊₁ = w⋅vₖ + c₁⋅rand()⋅(pbest - x) + c₂⋅rand()⋅(gbest - x)\$
72-
73-
\$xₖ₊₁ = xₖ + vₖ₊₁\$
74-
75-
New models are then generated for evaluation by mutating the fields of a deep
76-
copy of `model`. If the corresponding range has a specified `scale` function,
77-
then the transformation is applied before the hyperparameter is returned. For
78-
integer `NumericRange`s, the hyperparameter is rounded; and for `NominalRange`s,
79-
the hyperparameter is sampled from the specified values with the probability
80-
weights given by each particle.
81-
82-
Personal and social best models are then updated for the swarm. In order to
83-
replicate both the probability weights and the sampled value for `NominalRange`s
84-
of the best models, the weights of unselected values are shifted to the selected
85-
one by the `prob_shift` factor.
86-
"""
87-
mutable struct ParticleSwarm{T, R<:AbstractRNG} <: MLJTuning.TuningStrategy
88-
n_particles::Int
89-
w::T
90-
c1::T
91-
c2::T
92-
prob_shift::T
93-
rng::R
94-
# TODO: topology
95-
end
96-
97-
function ParticleSwarm(;
98-
n_particles::Int=3,
99-
w=1.0,
100-
c1=2.0,
101-
c2=2.0,
102-
prob_shift=0.25,
103-
rng::R=Random.GLOBAL_RNG
104-
) where {R}
105-
T = promote_type(typeof(inv(w)), typeof.((w, c1, c2, prob_shift))...)
106-
swarm = ParticleSwarm{T, R}(n_particles, w, c1, c2, prob_shift, rng)
107-
message = MLJTuning.clean!(swarm)
108-
isempty(message) || @warn message
109-
return swarm
110-
end
1+
abstract type AbstractParticleSwarm <: MLJTuning.TuningStrategy end
1112

1123
struct ParticleSwarmState{T, R, P, I}
1134
ranges::R

0 commit comments

Comments
 (0)