|
5 | 5 | using UnicodePlots |
6 | 6 | using Zygote |
7 | 7 |
|
8 | | - b = SubsetSelectionBenchmark() |
| 8 | + n = 25 |
| 9 | + k = 5 |
9 | 10 |
|
10 | | - dataset = generate_dataset(b, 500) |
11 | | - model = generate_statistical_model(b) |
12 | | - maximizer = generate_maximizer(b) |
13 | | - |
14 | | - # train_dataset, test_dataset = dataset[1:450], dataset[451:500] |
15 | | - # X_train = train_dataset.features |
16 | | - # Y_train = train_dataset.solutions |
17 | | - |
18 | | - # perturbed_maximizer = PerturbedAdditive(maximizer; ε=1.0, nb_samples=100) |
19 | | - # loss = FenchelYoungLoss(perturbed_maximizer) |
| 11 | + b = SubsetSelectionBenchmark(; n=n, k=k) |
20 | 12 |
|
21 | | - # starting_gap = compute_gap(b, test_dataset, model, maximizer) |
| 13 | + io = IOBuffer() |
| 14 | + show(io, b) |
| 15 | + @test String(take!(io)) == "SubsetSelectionBenchmark(n=25, k=5)" |
22 | 16 |
|
23 | | - # opt_state = Flux.setup(Adam(0.1), model) |
24 | | - # loss_history = Float64[] |
25 | | - # for epoch in 1:50 |
26 | | - # val, grads = Flux.withgradient(model) do m |
27 | | - # sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset) |
28 | | - # end |
29 | | - # Flux.update!(opt_state, model, grads[1]) |
30 | | - # push!(loss_history, val) |
31 | | - # end |
32 | | - |
33 | | - # final_gap = compute_gap(b, test_dataset, model, maximizer) |
| 17 | + dataset = generate_dataset(b, 50) |
| 18 | + model = generate_statistical_model(b) |
| 19 | + maximizer = generate_maximizer(b) |
34 | 20 |
|
35 | | - # lineplot(loss_history) |
36 | | - # @test loss_history[end] < loss_history[1] |
37 | | - # @test final_gap < starting_gap / 10 |
| 21 | + for (i, sample) in enumerate(dataset) |
| 22 | + x = sample.x |
| 23 | + θ_true = sample.θ |
| 24 | + y_true = sample.y |
| 25 | + @test size(x) == (n,) |
| 26 | + @test length(θ_true) == n |
| 27 | + @test length(y_true) == n |
| 28 | + @test isnothing(sample.instance) |
| 29 | + @test all(y_true .== maximizer(θ_true)) |
| 30 | + |
| 31 | + # Features and true weights should be equal |
| 32 | + @test all(θ_true .== x) |
| 33 | + |
| 34 | + θ = model(x) |
| 35 | + @test length(θ) == n |
| 36 | + |
| 37 | + y = maximizer(θ) |
| 38 | + @test length(y) == n |
| 39 | + @test sum(y) == k |
| 40 | + end |
38 | 41 | end |
0 commit comments