"""
    ParticleSwarm(n_particles = 3,
                  w = 1.0,
                  c1 = 2.0,
                  c2 = 2.0,
                  prob_shift = 0.25,
                  rng = Random.GLOBAL_RNG)

Instantiate a particle swarm optimization tuning strategy. A swarm is initiated
by sampling hyperparameters from their customizable priors, and new models are
generated by referencing each member's and the swarm's best models so far.

### Supported ranges

A single one-dimensional range or a vector of one-dimensional ranges can be
specified. `ParamRange` objects are constructed using the `range` method. If a
range is not paired with a prior, one is fitted as follows:

| Range Types             | Default Distribution |
|:----------------------- |:-------------------- |
| `NominalRange`          | `Dirichlet`          |
| Bounded `NumericRange`  | `Uniform`            |
| Positive `NumericRange` | `Gamma`              |
| Other `NumericRange`    | `Normal`             |

Specifically, in `ParticleSwarm`, the `range` field of a `TunedModel` instance
can be:

- a single one-dimensional range (`ParamRange` object) `r`

- a pair of the form `(r, d)`, with `r` as above and where `d` is:

  - a `Dirichlet` distribution with the same number of categories as `r.values`
    (for `NominalRange` `r`)

  - any `Distributions.UnivariateDistribution` *instance* (for `NumericRange`
    `r`)

  - one of the distribution *types* in the table below, for automatic fitting
    using `Distributions.fit(d, r)` to a distribution whose support always
    lies between `r.lower` and `r.upper` (for `NumericRange` `r`) or the set
    of probability vectors (for `NominalRange` `r`)

- any vector of objects of the above form

| Range Types             | Distribution Types                                                                           |
|:----------------------- |:-------------------------------------------------------------------------------------------- |
| `NominalRange`          | `Dirichlet`                                                                                  |
| Bounded `NumericRange`  | `Arcsine`, `Uniform`, `Biweight`, `Cosine`, `Epanechnikov`, `SymTriangularDist`, `Triweight` |
| Positive `NumericRange` | `Gamma`, `InverseGaussian`, `Poisson`                                                        |
| Any `NumericRange`      | `Normal`, `Logistic`, `LogNormal`, `Cauchy`, `Gumbel`, `Laplace`                             |

### Examples

    using Distributions

    range1 = range(model, :hyper1, lower=0, upper=1)

    range2 = [(range(model, :hyper1, lower=1, upper=10), Arcsine),
              range(model, :hyper2, lower=2, upper=Inf, unit=1, origin=3),
              (range(model, :hyper2, lower=2, upper=4), Normal(0, 3)),
              (range(model, :hyper3, values=[:ball, :tree]), Dirichlet)]
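
A range like `range2` can then be passed to MLJ's `TunedModel` wrapper. A
minimal usage sketch only, not a self-contained example: `model`, the data
`X`, `y`, and the measure are assumed to be defined by the user.

```julia
    # Sketch: ParticleSwarm plugs into TunedModel like any tuning strategy.
    tuned = TunedModel(model=model,
                       tuning=ParticleSwarm(n_particles=5),
                       range=range2,
                       measure=rms,
                       n=25)            # total models to evaluate
    mach = machine(tuned, X, y)
    fit!(mach)
```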

### Algorithm

Hyperparameter ranges are sampled and concatenated into position vectors for
each swarm particle. Velocities are initialized to zero, and in each iteration
every particle's position is updated to approach its personal best and the
swarm's best models so far, according to the equations:

\$ vₖ₊₁ = w⋅vₖ + c₁⋅rand()⋅(pbest - xₖ) + c₂⋅rand()⋅(gbest - xₖ)\$

\$ xₖ₊₁ = xₖ + vₖ₊₁\$

New models are then generated for evaluation by mutating the fields of a deep
copy of `model`. If the corresponding range has a specified `scale` function,
the transformation is applied before the hyperparameter is returned. For
integer `NumericRange`s the hyperparameter is rounded, and for `NominalRange`s
the hyperparameter is sampled from the specified values with the probability
weights given by each particle.

Personal and social best models are then updated for the swarm. In order to
replicate both the probability weights and the sampled value for the
`NominalRange`s of the best models, the weights of unselected values are
shifted to the selected one by the `prob_shift` factor.
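
The update equations above can be sketched in plain Julia. Here `pso_step` is
a hypothetical helper (not part of this package), assuming scalar positions
and supplied random draws `r1`, `r2` ∈ [0, 1) for illustration:

```julia
# One particle's velocity/position update, per the equations above.
function pso_step(x, v, pbest, gbest; w=1.0, c1=2.0, c2=2.0, r1=rand(), r2=rand())
    v_new = w*v + c1*r1*(pbest - x) + c2*r2*(gbest - x)
    return x + v_new, v_new
end

# Deterministic draws make the step reproducible:
x, v = pso_step(0.5, 0.0, 0.8, 0.9; r1=0.5, r2=0.5)
# v ≈ 0.7, x ≈ 1.2
```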
"""
mutable struct ParticleSwarm{R<:AbstractRNG} <: AbstractParticleSwarm
    n_particles::Int
    w::Float64
    c1::Float64
    c2::Float64
    prob_shift::Float64
    rng::R
    # TODO: topology
end

function ParticleSwarm(;
    n_particles=3,
    w=1.0,
    c1=2.0,
    c2=2.0,
    prob_shift=0.25,
    rng::R=Random.GLOBAL_RNG
) where {R}
    swarm = ParticleSwarm{R}(n_particles, w, c1, c2, prob_shift, rng)
    message = MLJTuning.clean!(swarm)
    isempty(message) || @warn message
    return swarm
end

function MLJTuning.clean!(tuning::ParticleSwarm)
    warning = ""
    if tuning.n_particles < 3
        warning *= "ParticleSwarm requires at least 3 particles. Resetting n_particles=3. "
        tuning.n_particles = 3
    end
    if tuning.w < 0
        warning *= "ParticleSwarm requires w ≥ 0. Resetting w=1. "
        tuning.w = 1
    end
    if tuning.c1 < 0
        warning *= "ParticleSwarm requires c1 ≥ 0. Resetting c1=2. "
        tuning.c1 = 2
    end
    if tuning.c2 < 0
        warning *= "ParticleSwarm requires c2 ≥ 0. Resetting c2=2. "
        tuning.c2 = 2
    end
    if !(0 ≤ tuning.prob_shift < 1)
        warning *= "ParticleSwarm requires 0 ≤ prob_shift < 1. Resetting prob_shift=0.25. "
        tuning.prob_shift = 0.25
    end
    return warning
end

function MLJTuning.setup(tuning::ParticleSwarm, model, ranges, n, verbosity)
    return initialize(ranges, tuning)
end

function MLJTuning.models(
    tuning::ParticleSwarm,
    model,
    history,
    state,
    n_remaining,
    verbosity
)
    n_particles = tuning.n_particles
    if !isnothing(history)
        sig = MLJTuning.signature(first(history).measure)
        pbest!(state, tuning, map(h -> sig * h.measurement[1], last(history, n_particles)))
        gbest!(state, tuning)
        move!(state, tuning)
    end
    retrieve!(state, tuning)
    fields = getproperty.(state.ranges, :field)
    new_models = map(1:n_particles) do i
        clone = deepcopy(model)
        for (field, param) in zip(fields, getindex.(state.parameters, i))
            recursive_setproperty!(clone, field, param)
        end
        clone
    end
    return new_models, state
end

function MLJTuning.tuning_report(tuning::ParticleSwarm, history, state)
    fields = getproperty.(state.ranges, :field)
    scales = MLJBase.scale.(state.ranges)
    return (; plotting = MLJTuning.plotting_report(fields, scales, history))
end