Skip to content

Commit 0d96fde

Browse files
committed
using gauntlet in unit tests
1 parent 153bba6 commit 0d96fde

File tree

3 files changed

+46
-20
lines changed

3 files changed

+46
-20
lines changed

test/gauntlet/experiments.jl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function experiment_focused()
8888
end
8989

9090

91-
function experiment_exponential()
91+
function experiment_exponential(faster::Bool)
9292
graph = [TravelGraph.cycle, TravelGraph.complete,]
9393
count = [TravelRateCount.destination, TravelRateCount.pair,]
9494
sampler_spec = [
@@ -97,7 +97,8 @@ function experiment_exponential()
9797
DirectMethod(:remove, :array),
9898
]
9999
state_cnt = [4, 5, 6]
100-
configurations = full_factorial(
100+
design = faster ? all_pairs : full_factorial
101+
configurations = design(
101102
graph, count, collect(1:length(sampler_spec)),
102103
state_cnt
103104
)
@@ -115,7 +116,7 @@ function experiment_exponential()
115116
return arrangements
116117
end
117118

118-
function experiment_range()
119+
function experiment_range(faster::Bool)
119120
memory = [TravelMemory.forget, TravelMemory.remember,]
120121
graph = [TravelGraph.cycle, TravelGraph.complete,]
121122
dist = [TravelRateDist.exponential, TravelRateDist.general,]
@@ -125,7 +126,8 @@ function experiment_range()
125126
FirstReactionMethod(), FirstToFireMethod(), NextReactionMethod(),
126127
]
127128
state_cnt = [4, 5, 6]
128-
configurations = full_factorial(
129+
design = faster ? all_pairs : full_factorial
130+
configurations = design(
129131
memory, graph, dist, count, delay, collect(1:length(sampler_spec)),
130132
state_cnt
131133
)
@@ -141,17 +143,18 @@ function experiment_range()
141143
end
142144

143145

144-
function run_experiments()
146+
function run_experiments(faster=false)
147+
echo = !faster
145148
rng_single = Xoshiro(882342987)
146-
configurations1 = experiment_range()
147-
configurations2 = experiment_exponential()
149+
configurations1 = experiment_range(faster)
150+
configurations2 = experiment_exponential(faster)
148151
configurations = vcat(configurations2, configurations1)
149152
# configurations = experiment_focused()
150-
println("There are $(length(configurations)) configurations.")
153+
echo && println("There are $(length(configurations)) configurations.")
151154
results = Vector{Any}(undef, length(configurations))
152155
for gen_idx in eachindex(configurations)
153156
sampler_spec, sut = configurations[gen_idx]
154-
println("spec $sampler_spec sut $sut")
157+
echo && println("spec $sampler_spec sut $sut")
155158
results[gen_idx] = collect_data_single(sampler_spec, sut, rng_single)
156159
end
157160
scores = Vector{Tuple{Float64,Int}}(undef, length(results))
@@ -165,27 +168,39 @@ function run_experiments()
165168
scores[score_idx] = (adjusted, score_idx)
166169
end
167170
sort!(scores)
168-
println("=" ^ 80)
169-
println("lowest scores")
170-
println("=" ^ 80)
171+
echo && println("=" ^ 80)
172+
echo && println("lowest scores")
173+
echo && println("=" ^ 80)
174+
succeed = true
171175
for examine in 1:5
172176
value, config_idx = scores[examine]
173177
config = configurations[config_idx]
174-
println("value $value")
175-
println("config $config")
178+
echo && println("value $value")
179+
echo && println("config $config")
176180
res_metrics = results[config_idx]
181+
group = Dict{Tuple{String,Int},Vector{Float64}}()
177182
for res in res_metrics
178-
println("metric $(res.name) $(res.pvalue) $(res.clock) $(res.count)")
183+
echo && println("metric $(res.name) $(res.pvalue) $(res.clock) $(res.count)")
184+
group[(res.name, res.clock)] = [res.pvalue]
179185
end
180-
println("=" ^ 80)
186+
echo && println("=" ^ 80)
181187
sampler_spec, sut = configurations[config_idx]
182188
for i in 1:5
183189
rep_metrics = collect_data_single(sampler_spec, sut, rng_single)
184-
println("-"^80)
190+
echo && println("-"^80)
185191
for res in rep_metrics
186-
println("metric $(res.name) $(res.pvalue) $(res.clock) $(res.count)")
192+
echo && println("metric $(res.name) $(res.pvalue) $(res.clock) $(res.count)")
193+
push!(group[(res.name, res.clock)], res.pvalue)
187194
end
188195
end
189-
println("=" ^ 80)
196+
for (metid, metvals) in group
197+
m = median(metvals)
198+
@assert m > 0.05
199+
if m < 0.05
200+
succeed = false
201+
end
202+
end
203+
echo && println("=" ^ 80)
190204
end
205+
return succeed
191206
end

test/gauntlet/test_travel.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ using Random
44
using Distributions
55
using Graphs
66
using Base.Threads
7+
using HypothesisTests
78

89
include("travel.jl")
10+
using .TravelModel
911
include("generate_data.jl")
1012
include("mark_calibration.jl")
13+
include("running_score.jl")
14+
include("experiments.jl")
1115

12-
using .TravelModel
1316

1417

1518
# Helper to create a vector of RNGs for threaded operations
@@ -240,3 +243,10 @@ end
240243
@test isfinite(total)
241244
end
242245
end
246+
247+
248+
@testset "Experiments Integration Tests" begin
249+
@testset "run_experiments faster mode" begin
250+
@test run_experiments(true)
251+
end
252+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ all_tests = [
2929
"test_vas_integrate.jl",
3030
"test_vas.jl",
3131
"test_with_common_random.jl",
32+
"gauntlet/test_travel.jl",
3233
]
3334

3435
# Filter tests based on command-line arguments (ARGS)

0 commit comments

Comments
 (0)