|
1 | 1 | @testset "Parameters" begin
|
| 2 | + # Swarm size |
2 | 3 | n = 3
|
| 4 | + # Test ranges |
3 | 5 | r1 = range(String, :r1; values=["a", "b", "c"])
|
4 | 6 | r2 = range(Int, :r2; lower=1, upper=3, scale=exp10)
|
5 | 7 | r3 = range(Float64, :r3; lower=1, upper=Inf, origin=4, unit=1)
|
6 | 8 | r4 = range(Float32, :r4; lower=-Inf, upper=Inf, origin=0, unit=1)
|
| 9 | + rs = (r1, r2, r3, r4) |
| 10 | + # Test distribution types |
| 11 | + Ds = (Dirichlet, Uniform, Gamma, Normal) |
| 12 | + # Manually fitted distributions for test ranges |
7 | 13 | d1 = Dirichlet(ones(3))
|
8 | 14 | d2 = Uniform(1, 3)
|
9 | 15 | d3 = truncated(Gamma(16, 0.25), 1, Inf)
|
10 | 16 | d4 = Normal(0, 1)
|
| 17 | + ds = (d1, d2, d3, d4) |
| 18 | + # Test range representation lengths |
| 19 | + lengths = (3, 1, 1, 1) |
| 20 | + # Test range's corresponding indices in internal representation state.X |
| 21 | + indices = (1:3, 4, 5, 6) |
| 22 | + # Generated samples from d1, d2, d3, d4 with StableRNG(1234) |
11 | 23 | X1 = [0.14280010160187237 0.49409071076694583 0.3631091876311819;
|
12 | 24 | 0.295010672512568 0.4534584876713112 0.25153083981612073;
|
13 | 25 | 0.12881930462550284 0.2617407494915029 0.6094399458829942]
|
14 | 26 | X2 = [2.7429797605672808, 2.3976392099947, 2.5742724788985445]
|
15 | 27 | X3 = [3.9372495283243105, 3.6569395920512977, 3.6354556967115146]
|
16 | 28 | X4 = [-0.8067647083847199, 0.420991611378423, 0.6736019046580138]
|
| 29 | + Xs = (X1, X2, X3, X4) |
17 | 30 |
|
18 | 31 | @testset "Initializer" begin
|
19 |
| - @test PSO._initializer(r1) === Dirichlet |
20 |
| - @test PSO._initializer(r2) === Uniform |
21 |
| - @test PSO._initializer(r3) === Gamma |
22 |
| - @test PSO._initializer(r4) === Normal |
| 32 | + for (r, D) in zip(rs, Ds) |
| 33 | + @test PSO._initializer(r) === D |
| 34 | + end |
23 | 35 | end
|
24 | 36 |
|
25 | 37 | @testset "Initialize with distribution types" begin
|
26 | 38 | rng = StableRNG(1234)
|
27 |
| - @test PSO._initialize(rng, r1, Dirichlet, n)[[1,3,4]] == (r1, 3, X1) |
28 |
| - @test PSO._initialize(rng, r2, Uniform, n)[[1,3,4]] == (r2, 1, X2) |
29 |
| - @test PSO._initialize(rng, r3, Gamma, n)[[1,3,4]] == (r3, 1, X3) |
30 |
| - @test PSO._initialize(rng, r4, Normal, n)[[1,3,4]] == (r4, 1, X4) |
| 39 | + for (r, D, l, X) in zip(rs, Ds, lengths, Xs) |
| 40 | + r̂, l̂, X̂ = PSO._initialize(rng, r, D, n)[[1,3,4]] |
| 41 | + @test r̂ === r |
| 42 | + @test l̂ == l |
| 43 | + @test X̂ ≈ X |
| 44 | + end |
31 | 45 | end
|
32 | 46 |
|
33 | 47 | @testset "Initialize with distributions" begin
|
34 | 48 | rng = StableRNG(1234)
|
35 |
| - @test PSO._initialize(rng, r1, d1, n)[[1,3,4]] == (r1, 3, X1) |
36 |
| - @test PSO._initialize(rng, r2, d2, n)[[1,3,4]] == (r2, 1, X2) |
37 |
| - @test PSO._initialize(rng, r3, d3, n)[[1,3,4]] == (r3, 1, X3) |
38 |
| - @test PSO._initialize(rng, r4, d4, n)[[1,3,4]] == (r4, 1, X4) |
| 49 | + for (r, d, l, X) in zip(rs, ds, lengths, Xs) |
| 50 | + r̂, l̂, X̂ = PSO._initialize(rng, r, d, n)[[1,3,4]] |
| 51 | + @test r̂ === r |
| 52 | + @test l̂ == l |
| 53 | + @test X̂ ≈ X |
| 54 | + end |
39 | 55 | end
|
40 | 56 |
|
41 | 57 | @testset "Range Indices" begin
|
42 |
| - @test PSO._to_indices((3,1,1,1)) == (1:3, 4, 5, 6) |
| 58 | + @test PSO._to_indices(lengths) == indices |
43 | 59 | end
|
44 | 60 |
|
45 | 61 | @testset "Unsupported distributions" begin
|
|
52 | 68 | end
|
53 | 69 |
|
54 | 70 | @testset "Initialize one range" begin
|
55 |
| - ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234)) |
56 |
| - state = PSO.initialize(r1, ps) |
57 |
| - @test state.ranges == (r1,) |
58 |
| - @test state.indices == (1:3,) |
59 |
| - @test state.X == X1 |
| 71 | + ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234)) |
| 72 | + for (r, l, i, X) in zip(rs, lengths, indices, Xs) |
| 73 | + state = PSO.initialize(r, ps) |
| 74 | + @test state.ranges == (r,) |
| 75 | + @test state.indices == (l == 1 ? 1 : 1:l,) |
| 76 | + @test state.X ≈ X |
| 77 | + end |
60 | 78 | end
|
61 | 79 |
|
62 | 80 | @testset "Initialize multiple ranges" begin
|
63 |
| - ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234)) |
| 81 | + ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234)) |
64 | 82 | ranges = [r1, (r2, Uniform), (r3, d3), r4]
|
65 | 83 | state = PSO.initialize(ranges, ps)
|
66 |
| - @test state.ranges == (r1, r2, r3, r4) |
67 |
| - @test state.indices == (1:3, 4, 5, 6) |
68 |
| - @test state.X == hcat(X1, X2, X3, X4) |
| 84 | + @test state.ranges == rs |
| 85 | + @test state.indices == indices |
| 86 | + @test state.X ≈ hcat(Xs...) |
69 | 87 | end
|
70 | 88 |
|
71 | 89 | @testset "Retrieve parameters" begin
|
72 |
| - ps = ParticleSwarm(n_particles=3, rng=StableRNG(1234)) |
| 90 | + ps = ParticleSwarm(n_particles=n, rng=StableRNG(1234)) |
73 | 91 | ranges = [r1, (r2, Uniform), (r3, d3), r4]
|
74 | 92 | state = PSO.initialize(ranges, ps)
|
75 | 93 | PSO.retrieve!(state, ps)
|
|
0 commit comments