Skip to content

Commit 028fc4b

Browse files
committed
Improve argmax benchmarks
1 parent fd4c1f8 commit 028fc4b

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

src/Argmax/Argmax.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothin
4040
return ArgmaxBenchmark(instance_dim, nb_features, model)
4141
end
4242

43+
Utils.is_minimization_problem(::ArgmaxBenchmark) = false
44+
4345
"""
4446
$TYPEDSIGNATURES
4547

src/Argmax2D/Argmax2D.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,21 @@ Custom constructor for [`Argmax2DBenchmark`](@ref).
4040
"""
4141
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
4242
Random.seed!(seed)
43+
<<<<<<< Updated upstream
4344
model = Chain(Dense(nb_features => 2; bias=false), vec)
4445
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
4546
end
4647

4748
maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]
49+
=======
50+
model = Dense(nb_features => 2; bias=false)
51+
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
52+
end
53+
54+
Utils.is_minimization_problem(::Argmax2DBenchmark) = false
55+
56+
maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instance)]
57+
>>>>>>> Stashed changes
4858

4959
"""
5060
$TYPEDSIGNATURES
@@ -56,7 +66,11 @@ function Utils.generate_dataset(
5666
)
5767
(; nb_features, encoder, polytope_vertex_range) = bench
5868
return map(1:dataset_size) do _
69+
<<<<<<< Updated upstream
5970
x = randn(rng, nb_features)
71+
=======
72+
x = randn(rng, Float32, nb_features)
73+
>>>>>>> Stashed changes
6074
θ_true = encoder(x)
6175
θ_true ./= 2 * norm(θ_true)
6276
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
@@ -84,23 +98,50 @@ function Utils.generate_statistical_model(
8498
)
8599
Random.seed!(rng, seed)
86100
(; nb_features) = bench
101+
<<<<<<< Updated upstream
87102
model = Chain(Dense(nb_features => 2; bias=false), vec)
88103
return model
89104
end
90105

106+
=======
107+
model = Dense(nb_features => 2; bias=false)
108+
return model
109+
end
110+
111+
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
112+
pl = init_plot()
113+
plot_polytope!(pl, instance)
114+
plot_objective!(pl, θ)
115+
return plot_maximizer!(pl, θ, instance, maximizer)
116+
end
117+
118+
>>>>>>> Stashed changes
91119
"""
92120
$TYPEDSIGNATURES
93121
94122
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
95123
"""
96124
function Utils.plot_data(
125+
<<<<<<< Updated upstream
97126
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
98127
)
99128
(; instance) = sample
100129
pl = init_plot()
101130
plot_polytope!(pl, instance)
102131
plot_objective!(pl, θ_true)
103132
return plot_maximizer!(pl, θ_true, instance, maximizer)
133+
=======
134+
::Argmax2DBenchmark,
135+
sample::DataSample;
136+
instance=sample.instance,
137+
θ=sample.θ_true,
138+
kwargs...,
139+
)
140+
pl = init_plot()
141+
plot_polytope!(pl, instance)
142+
plot_objective!(pl, θ)
143+
return plot_maximizer!(pl, θ, instance, maximizer)
144+
>>>>>>> Stashed changes
104145
end
105146

106147
export Argmax2DBenchmark

src/Argmax2D/polytope.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,29 @@ function plot_polytope!(pl, vertices)
2121
fillcolor=:gray,
2222
fillalpha=0.2,
2323
linecolor=:black,
24+
<<<<<<< Updated upstream
2425
label=L"\mathrm{conv}(\mathcal{V})",
2526
)
2627
end;
2728

2829
const logocolors = Colors.JULIA_LOGO_COLORS
2930

31+
=======
32+
label=L"\mathrm{conv}(\mathcal{Y}(x))",
33+
)
34+
end;
35+
36+
>>>>>>> Stashed changes
3037
function plot_objective!(pl, θ)
3138
Plots.plot!(
3239
pl,
3340
[0.0, θ[1]],
3441
[0.0, θ[2]];
42+
<<<<<<< Updated upstream
3543
color=logocolors.purple,
44+
=======
45+
color=Colors.JULIA_LOGO_COLORS.purple,
46+
>>>>>>> Stashed changes
3647
arrow=true,
3748
lw=2,
3849
label=nothing,
@@ -47,7 +58,11 @@ function plot_maximizer!(pl, θ, instance, maximizer)
4758
pl,
4859
[ŷ[1]],
4960
[ŷ[2]];
61+
<<<<<<< Updated upstream
5062
color=logocolors.red,
63+
=======
64+
color=Colors.JULIA_LOGO_COLORS.red,
65+
>>>>>>> Stashed changes
5166
markersize=9,
5267
markershape=:square,
5368
label=L"f(\theta)",
@@ -76,15 +91,23 @@ end;
7691
# fillcolor=:blue,
7792
# fillalpha=0.1,
7893
# linestyle=:dash,
94+
<<<<<<< Updated upstream
7995
# linecolor=logocolors.blue,
96+
=======
97+
# linecolor=Colors.JULIA_LOGO_COLORS.blue,
98+
>>>>>>> Stashed changes
8099
# label=L"\mathrm{conv}(\hat{p}(\theta))",
81100
# )
82101
# return Plots.scatter!(
83102
# pl,
84103
# map(first, A),
85104
# map(last, A);
86105
# markersize=25 .* p .^ 0.5,
106+
<<<<<<< Updated upstream
87107
# markercolor=logocolors.blue,
108+
=======
109+
# markercolor=Colors.JULIA_LOGO_COLORS.blue,
110+
>>>>>>> Stashed changes
88111
# markerstrokewidth=0,
89112
# markeralpha=0.4,
90113
# label=L"\hat{p}(\theta)",
@@ -97,15 +120,23 @@ end;
97120
# pl,
98121
# [ŷΩ[1]],
99122
# [ŷΩ[2]];
123+
<<<<<<< Updated upstream
100124
# color=logocolors.blue,
125+
=======
126+
# color=Colors.JULIA_LOGO_COLORS.blue,
127+
>>>>>>> Stashed changes
101128
# markersize=6,
102129
# markershape=:hexagon,
103130
# label=L"\hat{f}(\theta)",
104131
# )
105132
# end;
106133

107134
# function compress_distribution!(
135+
<<<<<<< Updated upstream
108136
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
137+
=======
138+
# probadist::DifferentiableExpectations.FixedAtomsProbabilityDistribution{A,W}; atol=0
139+
>>>>>>> Stashed changes
109140
# ) where {A,W}
110141
# (; atoms, weights) = probadist
111142
# to_delete = Int[]

0 commit comments

Comments
 (0)