@@ -13,11 +13,13 @@ Basic benchmark problem with an argmax as the CO algorithm.
1313# Fields
1414$TYPEDFIELDS
1515"""
16- struct ArgmaxBenchmark <: AbstractBenchmark
16+ struct ArgmaxBenchmark{E} <: AbstractBenchmark
1717 " instances dimension, total number of classes"
1818 instance_dim:: Int
1919 " number of features"
2020 nb_features:: Int
21+ " true mapping between features and costs"
22+ encoder:: E
2123end
2224
2325function Base. show (io:: IO , bench:: ArgmaxBenchmark )
@@ -27,8 +29,15 @@ function Base.show(io::IO, bench::ArgmaxBenchmark)
2729 )
2830end
2931
30- function ArgmaxBenchmark (; instance_dim:: Int = 10 , nb_features:: Int = 5 )
31- return ArgmaxBenchmark (instance_dim, nb_features)
32+ """
33+ $TYPEDSIGNATURES
34+
35+ Custom constructor for [`ArgmaxBenchmark`](@ref).
36+ """
37+ function ArgmaxBenchmark (; instance_dim:: Int = 10 , nb_features:: Int = 5 , seed= nothing )
38+ Random. seed! (seed)
39+ model = Chain (Dense (nb_features => 1 ; bias= false ), vec)
40+ return ArgmaxBenchmark (instance_dim, nb_features, model)
3241end
3342
3443"""
@@ -59,12 +68,10 @@ Generate a dataset of labeled instances for the argmax problem.
5968function Utils. generate_dataset (
6069 bench:: ArgmaxBenchmark , dataset_size:: Int = 10 ; seed:: Int = 0 , noise_std= 0.0
6170)
62- (; instance_dim, nb_features) = bench
71+ (; instance_dim, nb_features, encoder ) = bench
6372 rng = MersenneTwister (seed)
6473 features = [randn (rng, Float32, nb_features, instance_dim) for _ in 1 : dataset_size]
65- mapping = Chain (Dense (nb_features => 1 ; bias= false ), vec)
66- costs = mapping .(features)
67- # solutions = one_hot_argmax.(costs)
74+ costs = encoder .(features)
6875 noisy_solutions = [
6976 one_hot_argmax (θ + noise_std * randn (rng, Float32, instance_dim)) for θ in costs
7077 ]
@@ -79,9 +86,9 @@ $TYPEDSIGNATURES
7986
8087Initialize a linear model for `bench` using `Flux`.
8188"""
82- function Utils. generate_statistical_model (bench:: ArgmaxBenchmark ; seed= 0 )
83- Random. seed! (seed)
89+ function Utils. generate_statistical_model (bench:: ArgmaxBenchmark ; seed= nothing )
8490 (; nb_features) = bench
91+ Random. seed! (seed)
8592 return Chain (Dense (nb_features => 1 ; bias= false ), vec)
8693end
8794
0 commit comments