@@ -40,11 +40,15 @@ Custom constructor for [`Argmax2DBenchmark`](@ref).
4040"""
4141function Argmax2DBenchmark (; nb_features:: Int = 5 , seed= nothing , polytope_vertex_range= [6 ])
4242 Random. seed! (seed)
43- model = Chain ( Dense (nb_features => 2 ; bias= false ), vec )
43+ model = Dense (nb_features => 2 ; bias= false )
4444 return Argmax2DBenchmark (nb_features, model, polytope_vertex_range)
4545end
4646
47- maximizer (θ; instance) = instance[argmax (dot (θ, v) for v in instance)]
47+ function Utils. is_minimization_problem (:: Argmax2DBenchmark )
48+ return false
49+ end
50+
51+ maximizer (θ; instance, kwargs... ) = instance[argmax (dot (θ, v) for v in instance)]
4852
4953"""
5054$TYPEDSIGNATURES
@@ -56,7 +60,7 @@ function Utils.generate_dataset(
5660)
5761 (; nb_features, encoder, polytope_vertex_range) = bench
5862 return map (1 : dataset_size) do _
59- x = randn (rng, nb_features)
63+ x = randn (rng, Float32, nb_features)
6064 θ_true = encoder (x)
6165 θ_true ./= 2 * norm (θ_true)
6266 instance = build_polytope (rand (rng, polytope_vertex_range); shift= rand (rng))
@@ -84,23 +88,30 @@ function Utils.generate_statistical_model(
8488)
8589 Random. seed! (rng, seed)
8690 (; nb_features) = bench
87- model = Chain ( Dense (nb_features => 2 ; bias= false ), vec )
91+ model = Dense (nb_features => 2 ; bias= false )
8892 return model
8993end
9094
95+ function Utils. plot_data (:: Argmax2DBenchmark ; instance, θ, kwargs... )
96+ pl = init_plot ()
97+ plot_polytope! (pl, instance)
98+ plot_objective! (pl, θ)
99+ return plot_maximizer! (pl, θ, instance, maximizer)
100+ end
101+
91102"""
92103$TYPEDSIGNATURES
93104
94105Plot the data sample for the [`Argmax2DBenchmark`](@ref).
95106"""
96107function Utils. plot_data (
97- :: Argmax2DBenchmark , sample:: DataSample ; θ_true= sample. θ_true, kwargs...
108+ bench:: Argmax2DBenchmark ,
109+ sample:: DataSample ;
110+ instance= sample. instance,
111+ θ= sample. θ_true,
112+ kwargs... ,
98113)
99- (; instance) = sample
100- pl = init_plot ()
101- plot_polytope! (pl, instance)
102- plot_objective! (pl, θ_true)
103- return plot_maximizer! (pl, θ_true, instance, maximizer)
114+ return Utils. plot_data (bench; instance, θ, kwargs... )
104115end
105116
106117export Argmax2DBenchmark
0 commit comments