Skip to content

Commit 7e26cc0

Browse files
committed
Merge branch 'main' into StoVSP
2 parents d1c4cd4 + f1efe57 commit 7e26cc0

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

src/Argmax/Argmax.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@ using Random
88
"""
99
$TYPEDEF
1010
11-
Benchmark problem with an argmax as the CO algorithm.
11+
Basic benchmark problem with an argmax as the CO algorithm.
1212
1313
# Fields
1414
$TYPEDFIELDS
1515
"""
16-
struct ArgmaxBenchmark <: AbstractBenchmark
17-
"iinstances dimension, total number of classes"
16+
struct ArgmaxBenchmark{E} <: AbstractBenchmark
17+
"instances dimension, total number of classes"
1818
instance_dim::Int
1919
"number of features"
2020
nb_features::Int
21+
"true mapping between features and costs"
22+
encoder::E
2123
end
2224

2325
function Base.show(io::IO, bench::ArgmaxBenchmark)
@@ -27,8 +29,15 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
2729
)
2830
end
2931

30-
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5)
31-
return ArgmaxBenchmark(instance_dim, nb_features)
32+
"""
33+
$TYPEDSIGNATURES
34+
35+
Custom constructor for [`ArgmaxBenchmark`](@ref).
36+
"""
37+
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
38+
Random.seed!(seed)
39+
model = Chain(Dense(nb_features => 1; bias=false), vec)
40+
return ArgmaxBenchmark(instance_dim, nb_features, model)
3241
end
3342

3443
"""
@@ -45,7 +54,7 @@ end
4554
"""
4655
$TYPEDSIGNATURES
4756
48-
Return a top k maximizer.
57+
Return an argmax maximizer.
4958
"""
5059
function Utils.generate_maximizer(bench::ArgmaxBenchmark)
5160
return one_hot_argmax
@@ -54,19 +63,21 @@ end
5463
"""
5564
$TYPEDSIGNATURES
5665
57-
Generate a dataset of labeled instances for the subset selection problem.
58-
The mapping between features and cost is identity.
66+
Generate a dataset of labeled instances for the argmax problem.
5967
"""
60-
function Utils.generate_dataset(bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0)
61-
(; instance_dim, nb_features) = bench
68+
function Utils.generate_dataset(
69+
bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
70+
)
71+
(; instance_dim, nb_features, encoder) = bench
6272
rng = MersenneTwister(seed)
6373
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
64-
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
65-
costs = mapping.(features)
66-
solutions = one_hot_argmax.(costs)
74+
costs = encoder.(features)
75+
noisy_solutions = [
76+
one_hot_argmax+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
77+
]
6778
return [
6879
DataSample(; x, θ_true, y_true) for
69-
(x, θ_true, y_true) in zip(features, costs, solutions)
80+
(x, θ_true, y_true) in zip(features, costs, noisy_solutions)
7081
]
7182
end
7283

@@ -75,9 +86,9 @@ $TYPEDSIGNATURES
7586
7687
Initialize a linear model for `bench` using `Flux`.
7788
"""
78-
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=0)
79-
Random.seed!(seed)
89+
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=nothing)
8090
(; nb_features) = bench
91+
Random.seed!(seed)
8192
return Chain(Dense(nb_features => 1; bias=false), vec)
8293
end
8394

src/Ranking/Ranking.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@ using Random
88
"""
99
$TYPEDEF
1010
11-
Benchmark problem with an argmax as the CO algorithm.
11+
Basic benchmark problem with ranking as the CO algorithm.
1212
1313
# Fields
1414
$TYPEDFIELDS
1515
"""
16-
struct RankingBenchmark <: AbstractBenchmark
17-
"iinstances dimension, total number of classes"
16+
struct RankingBenchmark{E} <: AbstractBenchmark
17+
"instances dimension, total number of classes"
1818
instance_dim::Int
1919
"number of features"
2020
nb_features::Int
21+
"true mapping between features and costs"
22+
encoder::E
2123
end
2224

2325
function Base.show(io::IO, bench::RankingBenchmark)
@@ -27,8 +29,15 @@ function Base.show(io::IO, bench::RankingBenchmark)
2729
)
2830
end
2931

30-
function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5)
31-
return RankingBenchmark(instance_dim, nb_features)
32+
"""
33+
$TYPEDSIGNATURES
34+
35+
Custom constructor for [`RankingBenchmark`](@ref).
36+
"""
37+
function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
38+
Random.seed!(seed)
39+
model = Chain(Dense(nb_features => 1; bias=false), vec)
40+
return RankingBenchmark(instance_dim, nb_features, model)
3241
end
3342

3443
"""
@@ -43,7 +52,7 @@ end
4352
"""
4453
$TYPEDSIGNATURES
4554
46-
Return a top k maximizer.
55+
Return a ranking maximizer.
4756
"""
4857
function Utils.generate_maximizer(bench::RankingBenchmark)
4958
return ranking
@@ -52,19 +61,21 @@ end
5261
"""
5362
$TYPEDSIGNATURES
5463
55-
Generate a dataset of labeled instances for the subset selection problem.
56-
The mapping between features and cost is identity.
64+
Generate a dataset of labeled instances for the ranking problem.
5765
"""
58-
function Utils.generate_dataset(bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0)
59-
(; instance_dim, nb_features) = bench
66+
function Utils.generate_dataset(
67+
bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
68+
)
69+
(; instance_dim, nb_features, encoder) = bench
6070
rng = MersenneTwister(seed)
6171
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
62-
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
63-
costs = mapping.(features)
64-
solutions = ranking.(costs)
72+
costs = encoder.(features)
73+
noisy_solutions = [
74+
ranking.+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
75+
]
6576
return [
6677
DataSample(; x, θ_true, y_true) for
67-
(x, θ_true, y_true) in zip(features, costs, solutions)
78+
(x, θ_true, y_true) in zip(features, costs, noisy_solutions)
6879
]
6980
end
7081

@@ -73,9 +84,9 @@ $TYPEDSIGNATURES
7384
7485
Initialize a linear model for `bench` using `Flux`.
7586
"""
76-
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=0)
77-
Random.seed!(seed)
87+
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=nothing)
7888
(; nb_features) = bench
89+
Random.seed!(seed)
7990
return Chain(Dense(nb_features => 1; bias=false), vec)
8091
end
8192

0 commit comments

Comments
 (0)