Skip to content

Commit 7054ab5

Browse files
committed
updated deps
1 parent 029c81b commit 7054ab5

File tree

7 files changed

+49
-44
lines changed

7 files changed

+49
-44
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212

1313
[compat]
1414
Distributions = "0.25"
15+
MLJBase = "1.8"
1516
MLJTuning = "0.8"
1617
julia = "1"

src/MLJParticleSwarmOptimization.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ using MLJTuning
88

99
export ParticleSwarm, AdaptiveParticleSwarm
1010

11-
include("swarm.jl")
11+
include("interface.jl")
1212
include("parameters.jl")
1313
include("update.jl")
14-
include("strategies/abstract.jl")
1514
include("strategies/basic.jl")
1615
include("strategies/adaptive.jl")
1716

src/interface.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
abstract type AbstractParticleSwarm <: MLJTuning.TuningStrategy end
2+
3+
struct ParticleSwarmState{T, R, P, I}
4+
ranges::R
5+
parameters::P
6+
indices::I
7+
X::Matrix{T}
8+
V::Matrix{T}
9+
pbest_X::Matrix{T}
10+
gbest_X::Matrix{T}
11+
pbest::Vector{T}
12+
gbest::Vector{T}
13+
end
14+
15+
mutable struct ParticleSwarm <: AbstractParticleSwarm
16+
n_particles::Integer
17+
w::Float64
18+
c1::Float64
19+
c2::Float64
20+
prob_shift::Float64
21+
rng::AbstractRNG
22+
# TODO: topology
23+
end
24+
25+
mutable struct AdaptiveParticleSwarm <: AbstractParticleSwarm
26+
n_particles::Integer
27+
c1::Float64
28+
c2::Float64
29+
prob_shift::Float64
30+
rng::AbstractRNG
31+
end
32+
33+
get_n_particles(tuning::AbstractParticleSwarm) = tuning.n_particles
34+
get_prob_shift(tuning::AbstractParticleSwarm) = tuning.prob_shift
35+
get_rng(tuning::AbstractParticleSwarm) = tuning.rng
36+
37+
function initialize(r, tuning::AbstractParticleSwarm)
38+
return initialize(get_rng(tuning), r, get_n_particles(tuning))
39+
end
40+
41+
function retrieve!(state::ParticleSwarmState, tuning::AbstractParticleSwarm)
42+
return retrieve!(get_rng(tuning), state)
43+
end
44+
45+
function pbest!(state::ParticleSwarmState, measurements, tuning::AbstractParticleSwarm)
46+
return pbest!(state, measurements, get_prob_shift(tuning))
47+
end

src/strategies/abstract.jl

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/strategies/adaptive.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,6 @@ copy of `model`. If the corresponding range has a specified `scale` function,
4040
then the transformation is applied before the hyperparameter is returned. If
4141
`scale` is a symbol (eg, `:log`), it is ignored.
4242
"""
43-
mutable struct AdaptiveParticleSwarm{R<:AbstractRNG} <: AbstractParticleSwarm
44-
n_particles::Int
45-
c1::Float64
46-
c2::Float64
47-
prob_shift::Float64
48-
rng::R
49-
end
50-
51-
# Constructor
5243

5344
function AdaptiveParticleSwarm(;
5445
n_particles=3,

src/strategies/basic.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,6 @@ where pₛ is the probability of the sampled hyperparameter value. For more
103103
information, refer to "A New Discrete Particle Swarm Optimization Algorithm" by
104104
Strasser, Goodman, Sheppard, and Butcher.
105105
"""
106-
mutable struct ParticleSwarm{R<:AbstractRNG} <: AbstractParticleSwarm
107-
n_particles::Int
108-
w::Float64
109-
c1::Float64
110-
c2::Float64
111-
prob_shift::Float64
112-
rng::R
113-
# TODO: topology
114-
end
115106

116107
function ParticleSwarm(;
117108
n_particles=3,

src/swarm.jl

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)