Skip to content

Commit 92c8a00

Browse files
committed
Add tests with multi-threading/processing
1 parent 6c796d7 commit 92c8a00

File tree

3 files changed

+44
-38
lines changed

3 files changed

+44
-38
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
23
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
34
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
45
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ParticleSwarmOptimization
22
using Random
33
using Test
4+
using ComputationalResources
45
using Distributions
56
using EvoTrees
67
using MLJBase

test/strategies/basic.jl

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,46 +14,50 @@
1414
@test ps.prob_shift == 0.25
1515
end
1616

17-
@testset "EvoTree Tuning" begin
18-
rng = StableRNG(123)
19-
features = rand(rng, 10_000) .* 5 .- 2
20-
X = MLJBase.table(reshape(features, (size(features)[1], 1)))
21-
y = sin.(features) .* 0.5 .+ 0.5
22-
y = EvoTrees.logit(y) + randn(rng, size(y))
23-
y = EvoTrees.sigmoid(y)
17+
for acceleration in (CPU1(), CPUProcesses(), CPUThreads())
18+
@testset "EvoTree Tuning with $(typeof(acceleration))" begin
19+
rng = StableRNG(123)
20+
features = rand(rng, 10_000) .* 5 .- 2
21+
X = MLJBase.table(reshape(features, (size(features)[1], 1)))
22+
y = sin.(features) .* 0.5 .+ 0.5
23+
y = EvoTrees.logit(y) + randn(rng, size(y))
24+
y = EvoTrees.sigmoid(y)
2425

25-
tree = VERSION v"1.4" ? EvoTreeRegressor(rng=rng) : EvoTreeRegressor(seed=123)
26-
r1 = range(tree, :max_depth; values=[3:7;])
27-
r2 = range(tree, ; lower=-2, upper=0, scale=exp10)
26+
tree = EvoTreeRegressor(rng=rng)
27+
r1 = range(tree, :max_depth; values=[3:7;])
28+
r2 = range(tree, ; lower=-2, upper=0, scale=exp10)
2829

29-
baseline_self_tuning_tree = TunedModel(
30-
model=tree,
31-
tuning=RandomSearch(rng=StableRNG(1234)),
32-
# tuning=ParticleSwarm(n_particles=3, rng=rng),
33-
resampling=CV(nfolds=5, rng=StableRNG(8888)),
34-
range=[r1, r2],
35-
measure=(ŷ, y) -> mean(abs.(ŷ .- y)),
36-
n=15
37-
)
38-
baseline_mach = machine(baseline_self_tuning_tree, X, y)
39-
fit!(baseline_mach, verbosity=0)
40-
baseline_rep = report(baseline_mach)
41-
baseline_best_loss = baseline_rep.best_history_entry.measurement[1]
30+
baseline_self_tuning_tree = TunedModel(
31+
model=tree,
32+
tuning=RandomSearch(rng=StableRNG(1234)),
33+
# tuning=ParticleSwarm(n_particles=3, rng=rng),
34+
resampling=CV(nfolds=5, rng=StableRNG(8888)),
35+
range=[r1, r2],
36+
measure=(ŷ, y) -> mean(abs.(ŷ .- y)),
37+
n=15,
38+
acceleration=acceleration
39+
)
40+
baseline_mach = machine(baseline_self_tuning_tree, X, y)
41+
fit!(baseline_mach, verbosity=2)
42+
baseline_rep = report(baseline_mach)
43+
baseline_best_loss = baseline_rep.best_history_entry.measurement[1]
4244

43-
self_tuning_tree = TunedModel(
44-
model=tree,
45-
tuning=ParticleSwarm(rng=StableRNG(1234)),
46-
# tuning=ParticleSwarm(n_particles=3, rng=rng),
47-
resampling=CV(nfolds=5, rng=StableRNG(8888)),
48-
range=[r1, r2],
49-
measure=(ŷ, y) -> mean(abs.(ŷ .- y)),
50-
n=15
51-
)
52-
mach = machine(self_tuning_tree, X, y)
53-
fit!(mach, verbosity=0)
54-
rep = report(mach)
55-
best_loss = rep.best_history_entry.measurement[1]
45+
self_tuning_tree = TunedModel(
46+
model=tree,
47+
tuning=ParticleSwarm(rng=StableRNG(1234)),
48+
resampling=CV(nfolds=5, rng=StableRNG(8888)),
49+
range=[r1, r2],
50+
measure=(ŷ, y) -> mean(abs.(ŷ .- y)),
51+
n=15,
52+
acceleration=acceleration
53+
)
54+
mach = machine(self_tuning_tree, X, y)
55+
fit!(mach, verbosity=2)
56+
rep = report(mach)
57+
best_loss = rep.best_history_entry.measurement[1]
5658

57-
# Compare with random search result with the same settings
58-
@test best_loss < baseline_best_loss || isapprox(best_loss, baseline_best_loss, 1e-3)
59+
# Compare with random search result with the same settings
60+
@test best_loss < baseline_best_loss ||
61+
isapprox(best_loss, baseline_best_loss, 1e-3)
62+
end
5963
end

0 commit comments

Comments
 (0)