diff --git a/src/Argmax/Argmax.jl b/src/Argmax/Argmax.jl index cf90899..0864a96 100644 --- a/src/Argmax/Argmax.jl +++ b/src/Argmax/Argmax.jl @@ -13,11 +13,13 @@ Basic benchmark problem with an argmax as the CO algorithm. # Fields $TYPEDFIELDS """ -struct ArgmaxBenchmark <: AbstractBenchmark +struct ArgmaxBenchmark{E} <: AbstractBenchmark "instances dimension, total number of classes" instance_dim::Int "number of features" nb_features::Int + "true mapping between features and costs" + encoder::E end function Base.show(io::IO, bench::ArgmaxBenchmark) @@ -27,8 +29,15 @@ function Base.show(io::IO, bench::ArgmaxBenchmark) ) end -function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5) - return ArgmaxBenchmark(instance_dim, nb_features) +""" +$TYPEDSIGNATURES + +Custom constructor for [`ArgmaxBenchmark`](@ref). +""" +function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing) + Random.seed!(seed) + model = Chain(Dense(nb_features => 1; bias=false), vec) + return ArgmaxBenchmark(instance_dim, nb_features, model) end """ @@ -59,12 +68,10 @@ Generate a dataset of labeled instances for the argmax problem. function Utils.generate_dataset( bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0 ) - (; instance_dim, nb_features) = bench + (; instance_dim, nb_features, encoder) = bench rng = MersenneTwister(seed) features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size] - mapping = Chain(Dense(nb_features => 1; bias=false), vec) - costs = mapping.(features) - # solutions = one_hot_argmax.(costs) + costs = encoder.(features) noisy_solutions = [ one_hot_argmax(θ + noise_std * randn(rng, Float32, instance_dim)) for θ in costs ] @@ -79,9 +86,9 @@ $TYPEDSIGNATURES Initialize a linear model for `bench` using `Flux`. """ -function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=0) - Random.seed!(seed) +function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=nothing) (; nb_features) = bench + Random.seed!(seed) return Chain(Dense(nb_features => 1; bias=false), vec) end diff --git a/src/Ranking/Ranking.jl b/src/Ranking/Ranking.jl index 2bb38dc..8b93b8a 100644 --- a/src/Ranking/Ranking.jl +++ b/src/Ranking/Ranking.jl @@ -13,11 +13,13 @@ Basic benchmark problem with ranking as the CO algorithm. # Fields $TYPEDFIELDS """ -struct RankingBenchmark <: AbstractBenchmark +struct RankingBenchmark{E} <: AbstractBenchmark "instances dimension, total number of classes" instance_dim::Int "number of features" nb_features::Int + "true mapping between features and costs" + encoder::E end function Base.show(io::IO, bench::RankingBenchmark) @@ -27,8 +29,15 @@ function Base.show(io::IO, bench::RankingBenchmark) ) end -function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5) - return RankingBenchmark(instance_dim, nb_features) +""" +$TYPEDSIGNATURES + +Custom constructor for [`RankingBenchmark`](@ref). +""" +function RankingBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothing) + Random.seed!(seed) + model = Chain(Dense(nb_features => 1; bias=false), vec) + return RankingBenchmark(instance_dim, nb_features, model) end """ @@ -57,12 +66,10 @@ Generate a dataset of labeled instances for the ranking problem. function Utils.generate_dataset( bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0 ) - (; instance_dim, nb_features) = bench + (; instance_dim, nb_features, encoder) = bench rng = MersenneTwister(seed) features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size] - mapping = Chain(Dense(nb_features => 1; bias=false), vec) - costs = mapping.(features) - # solutions = ranking.(costs) + costs = encoder.(features) noisy_solutions = [ ranking(θ .+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs ] @@ -77,9 +84,9 @@ $TYPEDSIGNATURES Initialize a linear model for `bench` using `Flux`. """ -function Utils.generate_statistical_model(bench::RankingBenchmark; seed=0) - Random.seed!(seed) +function Utils.generate_statistical_model(bench::RankingBenchmark; seed=nothing) (; nb_features) = bench + Random.seed!(seed) return Chain(Dense(nb_features => 1; bias=false), vec) end