1
+ # ##
2
+ # ## Initialization
3
+ # ##
4
+
5
+ # Initialize particle swarm state
6
+
7
+ function initialize (
8
+ r:: Union{ParamRange, Tuple{ParamRange, Any}} ,
9
+ tuning:: AbstractParticleSwarm
10
+ )
11
+ return initialize ([r], tuning)
12
+ end
13
+
14
+ function initialize (rs:: AbstractVector , tuning:: AbstractParticleSwarm )
15
+ n = tuning. n_particles
16
+ # `length` is 1 for `NumericRange` and the number of categories for `NominalRange`
17
+ ranges, parameters, lengths, Xᵢ = zip (_initialize .(tuning. rng, rs, n)... )
18
+ indices = _to_indices (lengths)
19
+ X = hcat (Xᵢ... )
20
+ V = zero (X)
21
+ pbest_X = copy (X)
22
+ gbest_X = copy (X)
23
+ pbest = fill (eltype (X)(Inf ), n)
24
+ gbest = similar (pbest)
25
+ return ParticleSwarmState (
26
+ ranges, parameters, indices, X, V, pbest_X, gbest_X, pbest, gbest
27
+ )
28
+ end
29
+
30
+ # Unpack tuple of range and distribution
31
+
32
+ function _initialize (rng, t:: Tuple{ParamRange, Any} , n)
33
+ return _initialize (rng, t[1 ], t[2 ], n)
34
+ end
35
+
36
+ # Initialize parameters with default distributions
37
+
38
+ function _initialize (rng, r:: NominalRange{T, N} , n) where {T, N}
39
+ d = Dirichlet (ones (N))
40
+ return _initialize (rng, r, d, n)
41
+ end
42
+
43
+ function _initialize (rng, r:: NumericRange , n)
44
+ D = _initializer (MLJTuning. boundedness (r))
45
+ return _initialize (rng, r, D, n)
46
+ end
47
+
48
+ _initializer (:: Type{MLJBase.Bounded} ) = Uniform
49
+
50
+ _initializer (:: Type{MLJTuning.PositiveUnbounded} ) = Gamma
51
+
52
+ _initializer (:: Type{MLJTuning.Other} ) = Normal
53
+
54
+ # Fit distributions and initialize parameters
55
+
56
+ function _initialize (rng, r:: ParamRange , D:: Type{<:Distribution} , n)
57
+ throw (ArgumentError (" $D distribution is unsupported for $(typeof (r)) ." ))
58
+ end
59
+
60
+ function _initialize (rng, r:: NumericRange , D:: Type{<:UnivariateDistribution} , n)
61
+ d = Distributions. fit (D, r)
62
+ return _initialize (rng, r, d, n)
63
+ end
64
+
65
+ # Initialize parameters with fitted/provided distributions
66
+
67
+ function _initialize (rng, r:: ParamRange , d:: Distribution , n)
68
+ throw (ArgumentError (" $(typeof (d)) distribution is unsupported for $(typeof (r)) ." ))
69
+ end
70
+
71
+ function _initialize (rng, r:: NominalRange{T, N} , d:: Dirichlet , n) where {T, N}
72
+ N != d. alpha0 &&
73
+ throw (ArgumentError (" Provided distribution's number of categories don't match $r ." ))
74
+ p = Vector {T} (undef, n)
75
+ X = rand (rng, d, n)'
76
+ return r, p, N, X
77
+ end
78
+
79
+ function _initialize (rng, r:: NumericRange{T} , d:: UnivariateDistribution , n) where {T}
80
+ p = Vector {T} (undef, n)
81
+ X = rand (rng, d, n)
82
+ return r, p, 1 , X
83
+ end
84
+
85
+ # Helper function to get ranges' corresponding indices
86
+ # E.g., `_to_indices((1, 2, 1, 3))` returns `(1, 2:3, 4, 5:7)`
87
+
88
+ function _to_indices (lengths)
89
+ curr = 1
90
+ return map (lengths) do length
91
+ start = curr
92
+ stop = start + length - 1
93
+ curr = stop + 1
94
+ start == stop ? stop : (start: stop)
95
+ end
96
+ end
97
+
98
+ # ##
99
+ # ## Retrieval
100
+ # ##
101
+
102
+ # For updating `state.parameters`, the model hyperparameters to be returned, from their
103
+ # internal representation `state.X`
104
+
105
+ function retrieve! (state:: ParticleSwarmState , tuning:: AbstractParticleSwarm )
106
+ ranges, params, indices, X = state. ranges, state. parameters, state. indices, state. X
107
+ rng = tuning. rng
108
+ for (r, p, i) in zip (ranges, params, indices)
109
+ _retrieve! (rng, p, r, view (X, :, i))
110
+ end
111
+ return state
112
+ end
113
+
114
+ function _retrieve! (rng, p, r:: NominalRange , X)
115
+ return p .= getindex .(
116
+ Ref (r. values),
117
+ rand .(rng, Categorical .(X[i,:] for i in axes (X, 1 )))
118
+ )
119
+ end
120
+
121
+ function _retrieve! (rng, p, r:: NumericRange{T} , X) where {T<: Integer }
122
+ return @. p = round (T, _transform (r. scale, X))
123
+ end
124
+
125
+ function _retrieve! (rng, p, r:: NumericRange{T} , X) where {T<: Real }
126
+ return @. p = _transform (r. scale, X)
127
+ end
128
+
129
+ _transform (:: Symbol , X) = X
130
+
131
+ _transform (scale, X) = scale (X)
0 commit comments