Skip to content

Commit c379866

Browse files
committed
Update and fix tests
1 parent 17883d2 commit c379866

File tree

2 files changed

+23
-25
lines changed

2 files changed

+23
-25
lines changed

src/Argmax2D/Argmax2D.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,14 @@ function Utils.generate_dataset(
5050
bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
5151
)
5252
(; nb_features, encoder, polytope_vertex_range) = bench
53-
X = [randn(rng, nb_features) for _ in 1:dataset_size]
54-
θs = encoder.(X)
55-
θs ./= 2 * norm.(θs)
56-
instances = [
57-
build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng)) for
58-
_ in 1:dataset_size
59-
]
60-
Y = [maximizer(θ; instance) for (θ, instance) in zip(θs, instances)]
61-
return [
62-
DataSample(; x, θ_true, y_true, instance) for
63-
(x, θ_true, y_true, instance) in zip(X, θs, Y, instances)
64-
]
53+
return map(1:dataset_size) do _
54+
x = randn(rng, nb_features)
55+
θ_true = encoder(x)
56+
θ_true ./= 2 * norm(θ_true)
57+
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
58+
y_true = maximizer(θ_true; instance)
59+
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
60+
end
6561
end
6662

6763
Utils.generate_maximizer(::Argmax2DBenchmark) = maximizer

test/argmax_2d.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
11
@testitem "Argmax2D" begin
22
using DecisionFocusedLearningBenchmarks
33

4-
instance_dim = 10
54
nb_features = 5
6-
7-
b = ArgmaxBenchmark(; instance_dim=instance_dim, nb_features=nb_features)
5+
b = Argmax2DBenchmark(; nb_features=nb_features)
86

97
io = IOBuffer()
108
show(io, b)
11-
@test String(take!(io)) == "ArgmaxBenchmark(instance_dim=10, nb_features=5)"
9+
@test String(take!(io)) == "Argmax2DBenchmark(nb_features=5)"
1210

1311
dataset = generate_dataset(b, 50)
1412
model = generate_statistical_model(b)
1513
maximizer = generate_maximizer(b)
1614

1715
for (i, sample) in enumerate(dataset)
18-
(; x, θ_true, y_true) = sample
19-
@test size(x) == (nb_features, instance_dim)
20-
@test length(θ_true) == instance_dim
21-
@test length(y_true) == instance_dim
22-
@test isnothing(sample.instance)
23-
@test all(y_true .== maximizer(θ_true))
16+
(; x, θ_true, y_true, instance) = sample
17+
@test length(x) == nb_features
18+
@test length(θ_true) == 2 # 2D vectors
19+
@test length(y_true) == 2 # 2D point
20+
@test !isnothing(sample.instance) # instance is a polytope
21+
@test instance isa Vector{Vector{Float64}} # polytope is vector of 2D points
22+
@test all(length(vertex) == 2 for vertex in instance) # all vertices are 2D
23+
@test y_true in instance # solution should be a vertex of the polytope
24+
@test y_true == maximizer(θ_true; instance=instance)
2425

2526
θ = model(x)
26-
@test length(θ) == instance_dim
27+
@test length(θ) == 2 # 2D vector
2728

28-
y = maximizer(θ)
29-
@test length(y) == instance_dim
29+
y = maximizer(θ; instance=instance)
30+
@test length(y) == 2 # 2D point
31+
@test y in instance # solution should be a vertex of the polytope
3032
end
3133
end

0 commit comments

Comments
 (0)