Skip to content

Commit 3b12172

Browse files
authored
Merge pull request #5 from JuliaAI/dev
Add naive particle swarm tuning strategy
2 parents e71adce + 5ab2e55 commit 3b12172

File tree

15 files changed

+745
-18
lines changed

15 files changed

+745
-18
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.0'
20+
- '1.4'
2121
- '1'
2222
os:
2323
- ubuntu-latest
@@ -41,6 +41,8 @@ jobs:
4141
${{ runner.os }}-
4242
- uses: julia-actions/julia-buildpkg@v1
4343
- uses: julia-actions/julia-runtest@v1
44+
env:
45+
JULIA_NUM_THREADS: '2'
4446
- uses: julia-actions/julia-processcoverage@v1
4547
- uses: codecov/codecov-action@v1
4648
with:

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
/Manifest.toml
2+
/test/Manifest.toml

Project.toml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
name = "ParticleSwarmOptimization"
1+
name = "MLJParticleSwarmOptimization"
22
uuid = "17a086e9-ed03-4f30-ab88-8b63f0f6126c"
33
authors = ["Long Nguyen <[email protected]> and contributors"]
44
version = "0.1.0"
55

6-
[compat]
7-
julia = "1"
8-
9-
[extras]
10-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6+
[deps]
7+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
10+
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112

12-
[targets]
13-
test = ["Test"]
13+
[compat]
14+
Distributions = "0.25"
15+
MLJBase = "0.18"
16+
MLJTuning = "0.6"
17+
julia = "1.4"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module MLJParticleSwarmOptimization
2+
3+
using LinearAlgebra
4+
using Random
5+
using Distributions
6+
using MLJBase
7+
using MLJTuning
8+
9+
export StaticCoeffs, ParticleSwarm
10+
11+
include("swarm.jl")
12+
include("parameters.jl")
13+
include("update.jl")
14+
include("strategies/basic.jl")
15+
16+
end

src/ParticleSwarmOptimization.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/parameters.jl

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
###
2+
### Initialization
3+
###
4+
5+
# Initialize particle swarm state
6+
7+
function initialize(
8+
r::Union{ParamRange, Tuple{ParamRange, Any}},
9+
tuning::AbstractParticleSwarm
10+
)
11+
return initialize([r], tuning)
12+
end
13+
14+
function initialize(rs::AbstractVector, tuning::AbstractParticleSwarm)
15+
n = tuning.n_particles
16+
# `length` is 1 for `NumericRange` and the number of categories for `NominalRange`
17+
ranges, parameters, lengths, Xᵢ = zip(_initialize.(tuning.rng, rs, n)...)
18+
indices = _to_indices(lengths)
19+
X = hcat(Xᵢ...)
20+
V = zero(X)
21+
pbest_X = copy(X)
22+
gbest_X = copy(X)
23+
pbest = fill(eltype(X)(Inf), n)
24+
gbest = similar(pbest)
25+
return ParticleSwarmState(
26+
ranges, parameters, indices, X, V, pbest_X, gbest_X, pbest, gbest
27+
)
28+
end
29+
30+
# Unpack tuple of range and distribution
31+
32+
function _initialize(rng, t::Tuple{ParamRange, Any}, n)
33+
return _initialize(rng, t[1], t[2], n)
34+
end
35+
36+
# Initialize parameters with default distributions
37+
38+
function _initialize(rng, r::NominalRange{T, N}, n) where {T, N}
39+
d = Dirichlet(ones(N))
40+
return _initialize(rng, r, d, n)
41+
end
42+
43+
function _initialize(rng, r::NumericRange, n)
44+
D = _initializer(MLJTuning.boundedness(r))
45+
return _initialize(rng, r, D, n)
46+
end
47+
48+
_initializer(::Type{MLJBase.Bounded}) = Uniform
49+
50+
_initializer(::Type{MLJTuning.PositiveUnbounded}) = Gamma
51+
52+
_initializer(::Type{MLJTuning.Other}) = Normal
53+
54+
# Fit distributions and initialize parameters
55+
56+
function _initialize(rng, r::ParamRange, D::Type{<:Distribution}, n)
57+
throw(ArgumentError("$D distribution is unsupported for $(typeof(r))."))
58+
end
59+
60+
function _initialize(rng, r::NumericRange, D::Type{<:UnivariateDistribution}, n)
61+
d = Distributions.fit(D, r)
62+
return _initialize(rng, r, d, n)
63+
end
64+
65+
# Initialize parameters with fitted/provided distributions
66+
67+
function _initialize(rng, r::ParamRange, d::Distribution, n)
68+
throw(ArgumentError("$(typeof(d)) distribution is unsupported for $(typeof(r))."))
69+
end
70+
71+
function _initialize(rng, r::NominalRange{T, N}, d::Dirichlet, n) where {T, N}
72+
N != d.alpha0 &&
73+
throw(ArgumentError("Provided distribution's number of categories don't match $r."))
74+
p = Vector{T}(undef, n)
75+
X = rand(rng, d, n)'
76+
return r, p, N, X
77+
end
78+
79+
function _initialize(rng, r::NumericRange{T}, d::UnivariateDistribution, n) where {T}
80+
p = Vector{T}(undef, n)
81+
X = rand(rng, d, n)
82+
return r, p, 1, X
83+
end
84+
85+
# Helper function to get ranges' corresponding indices
86+
# E.g., `_to_indices((1, 2, 1, 3))` returns `(1, 2:3, 4, 5:7)`
87+
88+
function _to_indices(lengths)
89+
curr = 1
90+
return map(lengths) do length
91+
start = curr
92+
stop = start + length - 1
93+
curr = stop + 1
94+
start == stop ? stop : (start:stop)
95+
end
96+
end
97+
98+
###
99+
### Retrieval
100+
###
101+
102+
# For updating `state.parameters`, the model hyperparameters to be returned, from their
103+
# internal representation `state.X`
104+
105+
function retrieve!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
106+
ranges, params, indices, X = state.ranges, state.parameters, state.indices, state.X
107+
rng = tuning.rng
108+
for (r, p, i) in zip(ranges, params, indices)
109+
_retrieve!(rng, p, r, view(X, :, i))
110+
end
111+
return state
112+
end
113+
114+
function _retrieve!(rng, p, r::NominalRange, X)
115+
return p .= getindex.(
116+
Ref(r.values),
117+
rand.(rng, Categorical.(X[i,:] for i in axes(X, 1)))
118+
)
119+
end
120+
121+
function _retrieve!(rng, p, r::NumericRange{T}, X) where {T<:Integer}
122+
return @. p = round(T, _transform(r.scale, X))
123+
end
124+
125+
function _retrieve!(rng, p, r::NumericRange{T}, X) where {T<:Real}
126+
return @. p = _transform(r.scale, X)
127+
end
128+
129+
_transform(::Symbol, X) = X
130+
131+
_transform(scale, X) = scale(X)

0 commit comments

Comments
 (0)