@@ -40,11 +40,21 @@ Custom constructor for [`Argmax2DBenchmark`](@ref).
4040"""
4141function Argmax2DBenchmark (; nb_features:: Int = 5 , seed= nothing , polytope_vertex_range= [6 ])
4242 Random. seed! (seed)
43+ <<<<<< < Updated upstream
4344 model = Chain (Dense (nb_features => 2 ; bias= false ), vec)
4445 return Argmax2DBenchmark (nb_features, model, polytope_vertex_range)
4546end
4647
4748maximizer (θ; instance) = instance[argmax (dot (θ, v) for v in instance)]
49+ ====== =
50+ model = Dense (nb_features => 2 ; bias= false )
51+ return Argmax2DBenchmark (nb_features, model, polytope_vertex_range)
52+ end
53+
54+ Utils. is_minimization_problem (:: Argmax2DBenchmark ) = false
55+
56+ maximizer (θ; instance, kwargs... ) = instance[argmax (dot (θ, v) for v in instance)]
57+ >>>>>> > Stashed changes
4858
4959"""
5060$TYPEDSIGNATURES
@@ -56,7 +66,11 @@ function Utils.generate_dataset(
5666)
5767 (; nb_features, encoder, polytope_vertex_range) = bench
5868 return map (1 : dataset_size) do _
69+ <<<<<< < Updated upstream
5970 x = randn (rng, nb_features)
71+ ====== =
72+ x = randn (rng, Float32, nb_features)
73+ >>>>>> > Stashed changes
6074 θ_true = encoder (x)
6175 θ_true ./= 2 * norm (θ_true)
6276 instance = build_polytope (rand (rng, polytope_vertex_range); shift= rand (rng))
@@ -84,23 +98,50 @@ function Utils.generate_statistical_model(
8498)
8599 Random. seed! (rng, seed)
86100 (; nb_features) = bench
101+ <<<<<< < Updated upstream
87102 model = Chain (Dense (nb_features => 2 ; bias= false ), vec)
88103 return model
89104end
90105
106+ ====== =
107+ model = Dense (nb_features => 2 ; bias= false )
108+ return model
109+ end
110+
111+ function Utils. plot_data (:: Argmax2DBenchmark ; instance, θ, kwargs... )
112+ pl = init_plot ()
113+ plot_polytope! (pl, instance)
114+ plot_objective! (pl, θ)
115+ return plot_maximizer! (pl, θ, instance, maximizer)
116+ end
117+
118+ >>>>>> > Stashed changes
91119"""
92120$TYPEDSIGNATURES
93121
94122Plot the data sample for the [`Argmax2DBenchmark`](@ref).
95123"""
96124function Utils. plot_data (
125+ <<<<<< < Updated upstream
97126 :: Argmax2DBenchmark , sample:: DataSample ; θ_true= sample. θ_true, kwargs...
98127)
99128 (; instance) = sample
100129 pl = init_plot ()
101130 plot_polytope! (pl, instance)
102131 plot_objective! (pl, θ_true)
103132 return plot_maximizer! (pl, θ_true, instance, maximizer)
133+ ====== =
134+ :: Argmax2DBenchmark ,
135+ sample:: DataSample ;
136+ instance= sample. instance,
137+ θ= sample. θ_true,
138+ kwargs... ,
139+ )
140+ pl = init_plot ()
141+ plot_polytope! (pl, instance)
142+ plot_objective! (pl, θ)
143+ return plot_maximizer! (pl, θ, instance, maximizer)
144+ >>>>>> > Stashed changes
104145end
105146
106147export Argmax2DBenchmark
0 commit comments