Skip to content

Commit f1efe57

Browse files
authored
Merge pull request #18 from JuliaDecisionFocusedLearning/argmax-encoder
Store true encoder in Argmax and Ranking benchmarks
2 parents 8748674 + 247a9b7 commit f1efe57

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

src/Argmax/Argmax.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ Basic benchmark problem with an argmax as the CO algorithm.
1313
# Fields
1414
$TYPEDFIELDS
1515
"""
16-
struct ArgmaxBenchmark <: AbstractBenchmark
16+
struct ArgmaxBenchmark{E} <: AbstractBenchmark
1717
"instances dimension, total number of classes"
1818
instance_dim::Int
1919
"number of features"
2020
nb_features::Int
21+
"true mapping between features and costs"
22+
encoder::E
2123
end
2224

2325
function Base.show(io::IO, bench::ArgmaxBenchmark)
@@ -27,8 +29,15 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
2729
)
2830
end
2931

30-
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5)
31-
return ArgmaxBenchmark(instance_dim, nb_features)
32+
"""
33+
$TYPEDSIGNATURES
34+
35+
Custom constructor for [`ArgmaxBenchmark`](@ref).
36+
"""
37+
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
38+
Random.seed!(seed)
39+
model = Chain(Dense(nb_features => 1; bias=false), vec)
40+
return ArgmaxBenchmark(instance_dim, nb_features, model)
3241
end
3342

3443
"""
@@ -59,12 +68,10 @@ Generate a dataset of labeled instances for the argmax problem.
5968
function Utils.generate_dataset(
6069
bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
6170
)
62-
(; instance_dim, nb_features) = bench
71+
(; instance_dim, nb_features, encoder) = bench
6372
rng = MersenneTwister(seed)
6473
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
65-
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
66-
costs = mapping.(features)
67-
# solutions = one_hot_argmax.(costs)
74+
costs = encoder.(features)
6875
noisy_solutions = [
6976
one_hot_argmax+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
7077
]
@@ -79,9 +86,9 @@ $TYPEDSIGNATURES
7986
8087
Initialize a linear model for `bench` using `Flux`.
8188
"""
82-
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=0)
83-
Random.seed!(seed)
89+
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=nothing)
8490
(; nb_features) = bench
91+
Random.seed!(seed)
8592
return Chain(Dense(nb_features => 1; bias=false), vec)
8693
end
8794

src/Ranking/Ranking.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ Basic benchmark problem with ranking as the CO algorithm.
1313
# Fields
1414
$TYPEDFIELDS
1515
"""
16-
struct RankingBenchmark <: AbstractBenchmark
16+
struct RankingBenchmark{E} <: AbstractBenchmark
1717
"instances dimension, total number of classes"
1818
instance_dim::Int
1919
"number of features"
2020
nb_features::Int
21+
"true mapping between features and costs"
22+
encoder::E
2123
end
2224

2325
function Base.show(io::IO, bench::RankingBenchmark)
@@ -27,8 +29,15 @@ function Base.show(io::IO, bench::RankingBenchmark)
2729
)
2830
end
2931

30-
function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5)
31-
return RankingBenchmark(instance_dim, nb_features)
32+
"""
33+
$TYPEDSIGNATURES
34+
35+
Custom constructor for [`RankingBenchmark`](@ref).
36+
"""
37+
function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing)
38+
Random.seed!(seed)
39+
model = Chain(Dense(nb_features => 1; bias=false), vec)
40+
return RankingBenchmark(instance_dim, nb_features, model)
3241
end
3342

3443
"""
@@ -57,12 +66,10 @@ Generate a dataset of labeled instances for the ranking problem.
5766
function Utils.generate_dataset(
5867
bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
5968
)
60-
(; instance_dim, nb_features) = bench
69+
(; instance_dim, nb_features, encoder) = bench
6170
rng = MersenneTwister(seed)
6271
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
63-
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
64-
costs = mapping.(features)
65-
# solutions = ranking.(costs)
72+
costs = encoder.(features)
6673
noisy_solutions = [
6774
ranking.+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
6875
]
@@ -77,9 +84,9 @@ $TYPEDSIGNATURES
7784
7885
Initialize a linear model for `bench` using `Flux`.
7986
"""
80-
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=0)
81-
Random.seed!(seed)
87+
function Utils.generate_statistical_model(bench::RankingBenchmark; seed=nothing)
8288
(; nb_features) = bench
89+
Random.seed!(seed)
8390
return Chain(Dense(nb_features => 1; bias=false), vec)
8491
end
8592

0 commit comments

Comments
 (0)