Skip to content

Commit 75703c8

Browse files
committed
Update tests
1 parent ae830db commit 75703c8

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

test/parameters.jl

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,61 @@
11
@testset "Parameters" begin
2+
# Swarm size
23
n = 3
4+
# Test ranges
35
r1 = range(String, :r1; values=["a", "b", "c"])
46
r2 = range(Int, :r2; lower=1, upper=3, scale=exp10)
57
r3 = range(Float64, :r3; lower=1, upper=Inf, origin=4, unit=1)
68
r4 = range(Float32, :r4; lower=-Inf, upper=Inf, origin=0, unit=1)
9+
rs = (r1, r2, r3, r4)
10+
# Test distribution types
11+
Ds = (Dirichlet, Uniform, Gamma, Normal)
12+
# Manually fitted distributions for test ranges
713
d1 = Dirichlet(ones(3))
814
d2 = Uniform(1, 3)
915
d3 = truncated(Gamma(16, 0.25), 1, Inf)
1016
d4 = Normal(0, 1)
17+
ds = (d1, d2, d3, d4)
18+
# Test range representation lengths
19+
lengths = (3, 1, 1, 1)
20+
# Test range's corresponding indices in internal representation state.X
21+
indices = (1:3, 4, 5, 6)
22+
# Generated samples from d1, d2, d3, d4 with StableRNG(1234)
1123
X1 = [0.14280010160187237 0.49409071076694583 0.3631091876311819;
1224
0.295010672512568 0.4534584876713112 0.25153083981612073;
1325
0.12881930462550284 0.2617407494915029 0.6094399458829942]
1426
X2 = [2.7429797605672808, 2.3976392099947, 2.5742724788985445]
1527
X3 = [3.9372495283243105, 3.6569395920512977, 3.6354556967115146]
1628
X4 = [-0.8067647083847199, 0.420991611378423, 0.6736019046580138]
29+
Xs = (X1, X2, X3, X4)
1730

1831
@testset "Initializer" begin
19-
@test PSO._initializer(r1) === Dirichlet
20-
@test PSO._initializer(r2) === Uniform
21-
@test PSO._initializer(r3) === Gamma
22-
@test PSO._initializer(r4) === Normal
32+
for (r, D) in zip(rs, Ds)
33+
@test PSO._initializer(r) === D
34+
end
2335
end
2436

2537
@testset "Initialize with distribution types" begin
2638
rng = StableRNG(1234)
27-
@test PSO._initialize(rng, r1, Dirichlet, n)[[1,3,4]] == (r1, 3, X1)
28-
@test PSO._initialize(rng, r2, Uniform, n)[[1,3,4]] == (r2, 1, X2)
29-
@test PSO._initialize(rng, r3, Gamma, n)[[1,3,4]] == (r3, 1, X3)
30-
@test PSO._initialize(rng, r4, Normal, n)[[1,3,4]] == (r4, 1, X4)
39+
for (r, D, l, X) in zip(rs, Ds, lengths, Xs)
40+
r̂, l̂, X̂ = PSO._initialize(rng, r, D, n)[[1,3,4]]
41+
@test=== r
42+
@test== l
43+
@test X
44+
end
3145
end
3246

3347
@testset "Initialize with distributions" begin
3448
rng = StableRNG(1234)
35-
@test PSO._initialize(rng, r1, d1, n)[[1,3,4]] == (r1, 3, X1)
36-
@test PSO._initialize(rng, r2, d2, n)[[1,3,4]] == (r2, 1, X2)
37-
@test PSO._initialize(rng, r3, d3, n)[[1,3,4]] == (r3, 1, X3)
38-
@test PSO._initialize(rng, r4, d4, n)[[1,3,4]] == (r4, 1, X4)
49+
for (r, d, l, X) in zip(rs, ds, lengths, Xs)
50+
r̂, l̂, X̂ = PSO._initialize(rng, r, d, n)[[1,3,4]]
51+
@test=== r
52+
@test== l
53+
@test X
54+
end
3955
end
4056

4157
@testset "Range Indices" begin
42-
@test PSO._to_indices((3,1,1,1)) == (1:3, 4, 5, 6)
58+
@test PSO._to_indices(lengths) == indices
4359
end
4460

4561
@testset "Unsupported distributions" begin
@@ -52,24 +68,26 @@
5268
end
5369

5470
@testset "Initialize one range" begin
55-
ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234))
56-
state = PSO.initialize(r1, ps)
57-
@test state.ranges == (r1,)
58-
@test state.indices == (1:3,)
59-
@test state.X == X1
71+
ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
72+
for (r, l, i, X) in zip(rs, lengths, indices, Xs)
73+
state = PSO.initialize(r, ps)
74+
@test state.ranges == (r,)
75+
@test state.indices == (l == 1 ? 1 : 1:l,)
76+
@test state.X X
77+
end
6078
end
6179

6280
@testset "Initialize multiple ranges" begin
63-
ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234))
81+
ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
6482
ranges = [r1, (r2, Uniform), (r3, d3), r4]
6583
state = PSO.initialize(ranges, ps)
66-
@test state.ranges == (r1, r2, r3, r4)
67-
@test state.indices == (1:3, 4, 5, 6)
68-
@test state.X == hcat(X1, X2, X3, X4)
84+
@test state.ranges == rs
85+
@test state.indices == indices
86+
@test state.X hcat(Xs...)
6987
end
7088

7189
@testset "Retrieve parameters" begin
72-
ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234))
90+
ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234))
7391
ranges = [r1, (r2, Uniform), (r3, d3), r4]
7492
state = PSO.initialize(ranges, ps)
7593
PSO.retrieve!(state, ps)

0 commit comments

Comments
 (0)