Skip to content

Commit 7b5c4d1

Browse files
committed
tests for subset selection
1 parent f4e9d76 commit 7b5c4d1

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

test/subset_selection.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,37 @@
55
using UnicodePlots
66
using Zygote
77

8-
b = SubsetSelectionBenchmark()
8+
n = 25
9+
k = 5
910

10-
dataset = generate_dataset(b, 500)
11-
model = generate_statistical_model(b)
12-
maximizer = generate_maximizer(b)
13-
14-
# train_dataset, test_dataset = dataset[1:450], dataset[451:500]
15-
# X_train = train_dataset.features
16-
# Y_train = train_dataset.solutions
17-
18-
# perturbed_maximizer = PerturbedAdditive(maximizer; ε=1.0, nb_samples=100)
19-
# loss = FenchelYoungLoss(perturbed_maximizer)
11+
b = SubsetSelectionBenchmark(; n=n, k=k)
2012

21-
# starting_gap = compute_gap(b, test_dataset, model, maximizer)
13+
io = IOBuffer()
14+
show(io, b)
15+
@test String(take!(io)) == "SubsetSelectionBenchmark(n=25, k=5)"
2216

23-
# opt_state = Flux.setup(Adam(0.1), model)
24-
# loss_history = Float64[]
25-
# for epoch in 1:50
26-
# val, grads = Flux.withgradient(model) do m
27-
# sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset)
28-
# end
29-
# Flux.update!(opt_state, model, grads[1])
30-
# push!(loss_history, val)
31-
# end
32-
33-
# final_gap = compute_gap(b, test_dataset, model, maximizer)
17+
dataset = generate_dataset(b, 50)
18+
model = generate_statistical_model(b)
19+
maximizer = generate_maximizer(b)
3420

35-
# lineplot(loss_history)
36-
# @test loss_history[end] < loss_history[1]
37-
# @test final_gap < starting_gap / 10
21+
for (i, sample) in enumerate(dataset)
22+
x = sample.x
23+
θ_true = sample.θ
24+
y_true = sample.y
25+
@test size(x) == (n,)
26+
@test length(θ_true) == n
27+
@test length(y_true) == n
28+
@test isnothing(sample.instance)
29+
@test all(y_true .== maximizer(θ_true))
30+
31+
# Features and true weights should be equal
32+
@test all(θ_true .== x)
33+
34+
θ = model(x)
35+
@test length(θ) == n
36+
37+
y = maximizer(θ)
38+
@test length(y) == n
39+
@test sum(y) == k
40+
end
3841
end

0 commit comments

Comments
 (0)