Skip to content

Commit 17883d2

Browse files
committed
Argmax2DBenchmark
1 parent 7255da8 commit 17883d2

File tree

5 files changed

+255
-0
lines changed

5 files changed

+255
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Members of JuliaDecisionFocusedLearning"]
44
version = "0.2.2"
55

66
[deps]
7+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
78
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
89
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
910
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -14,6 +15,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
1415
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
1516
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
1617
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
18+
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
1719
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1820
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
1921
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
@@ -28,6 +30,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2830
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2931

3032
[compat]
33+
Colors = "0.13.1"
3134
ConstrainedShortestPaths = "0.6.0"
3235
DataDeps = "0.7"
3336
Distributions = "0.25"
@@ -38,6 +41,7 @@ HiGHS = "1.9"
3841
Images = "0.26.1"
3942
Ipopt = "1.6"
4043
JuMP = "1.22"
44+
LaTeXStrings = "1.4.0"
4145
LinearAlgebra = "1"
4246
Metalhead = "0.9.4"
4347
NPZ = "0.4"

src/Argmax2D/Argmax2D.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
module Argmax2D
2+
3+
using ..Utils
4+
using Colors: Colors
5+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
6+
using Flux: Chain, Dense
7+
using LaTeXStrings: @L_str
8+
using LinearAlgebra: dot, norm
9+
using Plots: Plots
10+
using Random: Random, MersenneTwister
11+
12+
include("polytope.jl")
13+
14+
"""
15+
$TYPEDEF
16+
17+
Argmax becnhmark on a 2d polytope.
18+
19+
# Fields
20+
$TYPEDFIELDS
21+
"""
22+
struct Argmax2DBenchmark{E,R} <: AbstractBenchmark
23+
"number of features"
24+
nb_features::Int
25+
"true mapping between features and costs"
26+
encoder::E
27+
""
28+
polytope_vertex_range::R
29+
end
30+
31+
function Base.show(io::IO, bench::Argmax2DBenchmark)
32+
(; nb_features) = bench
33+
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
34+
end
35+
36+
"""
37+
$TYPEDSIGNATURES
38+
39+
Custom constructor for [`ArgmaxBenchmark`](@ref).
40+
"""
41+
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
42+
Random.seed!(seed)
43+
model = Chain(Dense(nb_features => 2; bias=false), vec)
44+
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
45+
end
46+
47+
maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]
48+
49+
function Utils.generate_dataset(
50+
bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
51+
)
52+
(; nb_features, encoder, polytope_vertex_range) = bench
53+
X = [randn(rng, nb_features) for _ in 1:dataset_size]
54+
θs = encoder.(X)
55+
θs ./= 2 * norm.(θs)
56+
instances = [
57+
build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng)) for
58+
_ in 1:dataset_size
59+
]
60+
Y = [maximizer(θ; instance) for (θ, instance) in zip(θs, instances)]
61+
return [
62+
DataSample(; x, θ_true, y_true, instance) for
63+
(x, θ_true, y_true, instance) in zip(X, θs, Y, instances)
64+
]
65+
end
66+
67+
Utils.generate_maximizer(::Argmax2DBenchmark) = maximizer
68+
69+
function Utils.generate_statistical_model(
70+
bench::Argmax2DBenchmark; seed=nothing, rng=MersenneTwister(seed)
71+
)
72+
Random.seed!(rng, seed)
73+
(; nb_features) = bench
74+
model = Chain(Dense(nb_features => 2; bias=false), vec)
75+
return model
76+
end
77+
78+
function Utils.plot_data(
79+
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
80+
)
81+
(; instance) = sample
82+
pl = init_plot()
83+
plot_polytope!(pl, instance)
84+
plot_objective!(pl, θ_true)
85+
return plot_maximizer!(pl, θ_true, instance, maximizer)
86+
end
87+
88+
export Argmax2DBenchmark
89+
90+
end

src/Argmax2D/polytope.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
function build_polytope(N; shift=0.0)
2+
return [[cospi(2k / N + shift), sinpi(2k / N + shift)] for k in 0:(N - 1)]
3+
end
4+
5+
function init_plot(title="")
6+
pl = Plots.plot(;
7+
aspect_ratio=:equal,
8+
legend=:outerleft,
9+
xlim=(-1.1, 1.1),
10+
ylim=(-1.1, 1.1),
11+
title=title,
12+
)
13+
return pl
14+
end;
15+
16+
function plot_polytope!(pl, vertices)
17+
return Plots.plot!(
18+
vcat(map(first, vertices), first(vertices[1])),
19+
vcat(map(last, vertices), last(vertices[1]));
20+
fillrange=0,
21+
fillcolor=:gray,
22+
fillalpha=0.2,
23+
linecolor=:black,
24+
label=L"\mathrm{conv}(\mathcal{V})",
25+
)
26+
end;
27+
28+
const logocolors = Colors.JULIA_LOGO_COLORS
29+
30+
function plot_objective!(pl, θ)
31+
Plots.plot!(
32+
pl,
33+
[0.0, θ[1]],
34+
[0.0, θ[2]];
35+
color=logocolors.purple,
36+
arrow=true,
37+
lw=2,
38+
label=nothing,
39+
)
40+
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
41+
return pl
42+
end;
43+
44+
function plot_maximizer!(pl, θ, instance, maximizer)
45+
= maximizer(θ; instance)
46+
return Plots.scatter!(
47+
pl,
48+
[ŷ[1]],
49+
[ŷ[2]];
50+
color=logocolors.red,
51+
markersize=9,
52+
markershape=:square,
53+
label=L"f(\theta)",
54+
)
55+
end;
56+
57+
# function get_angle(v)
58+
# @assert !(norm(v) ≈ 0)
59+
# v = v ./ norm(v)
60+
# if v[2] >= 0
61+
# return acos(v[1])
62+
# else
63+
# return π + acos(-v[1])
64+
# end
65+
# end;
66+
67+
# function plot_distribution!(pl, probadist)
68+
# A = probadist.atoms
69+
# As = sort(A; by=get_angle)
70+
# p = probadist.weights
71+
# Plots.plot!(
72+
# pl,
73+
# vcat(map(first, As), first(As[1])),
74+
# vcat(map(last, As), last(As[1]));
75+
# fillrange=0,
76+
# fillcolor=:blue,
77+
# fillalpha=0.1,
78+
# linestyle=:dash,
79+
# linecolor=logocolors.blue,
80+
# label=L"\mathrm{conv}(\hat{p}(\theta))",
81+
# )
82+
# return Plots.scatter!(
83+
# pl,
84+
# map(first, A),
85+
# map(last, A);
86+
# markersize=25 .* p .^ 0.5,
87+
# markercolor=logocolors.blue,
88+
# markerstrokewidth=0,
89+
# markeralpha=0.4,
90+
# label=L"\hat{p}(\theta)",
91+
# )
92+
# end;
93+
94+
# function plot_expectation!(pl, probadist)
95+
# ŷΩ = compute_expectation(probadist)
96+
# return scatter!(
97+
# pl,
98+
# [ŷΩ[1]],
99+
# [ŷΩ[2]];
100+
# color=logocolors.blue,
101+
# markersize=6,
102+
# markershape=:hexagon,
103+
# label=L"\hat{f}(\theta)",
104+
# )
105+
# end;
106+
107+
# function compress_distribution!(
108+
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
109+
# ) where {A,W}
110+
# (; atoms, weights) = probadist
111+
# to_delete = Int[]
112+
# for i in length(probadist):-1:1
113+
# ai = atoms[i]
114+
# for j in 1:(i - 1)
115+
# aj = atoms[j]
116+
# if isapprox(ai, aj; atol=atol)
117+
# weights[j] += weights[i]
118+
# push!(to_delete, i)
119+
# break
120+
# end
121+
# end
122+
# end
123+
# sort!(to_delete)
124+
# deleteat!(atoms, to_delete)
125+
# deleteat!(weights, to_delete)
126+
# return probadist
127+
# end;

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ end
2424
include("Utils/Utils.jl")
2525

2626
include("Argmax/Argmax.jl")
27+
include("Argmax2D/Argmax2D.jl")
2728
include("Ranking/Ranking.jl")
2829
include("SubsetSelection/SubsetSelection.jl")
2930
include("Warcraft/Warcraft.jl")
@@ -33,6 +34,7 @@ include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
3334

3435
using .Utils
3536
using .Argmax
37+
using .Argmax2D
3638
using .Ranking
3739
using .SubsetSelection
3840
using .Warcraft
@@ -51,6 +53,7 @@ export compute_gap
5153

5254
# Export all benchmarks
5355
export ArgmaxBenchmark
56+
export Argmax2DBenchmark
5457
export RankingBenchmark
5558
export SubsetSelectionBenchmark
5659
export WarcraftBenchmark

test/argmax_2d.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
@testitem "Argmax2D" begin
2+
using DecisionFocusedLearningBenchmarks
3+
4+
instance_dim = 10
5+
nb_features = 5
6+
7+
b = ArgmaxBenchmark(; instance_dim=instance_dim, nb_features=nb_features)
8+
9+
io = IOBuffer()
10+
show(io, b)
11+
@test String(take!(io)) == "ArgmaxBenchmark(instance_dim=10, nb_features=5)"
12+
13+
dataset = generate_dataset(b, 50)
14+
model = generate_statistical_model(b)
15+
maximizer = generate_maximizer(b)
16+
17+
for (i, sample) in enumerate(dataset)
18+
(; x, θ_true, y_true) = sample
19+
@test size(x) == (nb_features, instance_dim)
20+
@test length(θ_true) == instance_dim
21+
@test length(y_true) == instance_dim
22+
@test isnothing(sample.instance)
23+
@test all(y_true .== maximizer(θ_true))
24+
25+
θ = model(x)
26+
@test length(θ) == instance_dim
27+
28+
y = maximizer(θ)
29+
@test length(y) == instance_dim
30+
end
31+
end

0 commit comments

Comments
 (0)