Skip to content

Commit fd4c1f8

Browse files
authored
Merge pull request #31 from JuliaDecisionFocusedLearning/2d-polytope
Argmax2DBenchmark
2 parents 7255da8 + cdfe200 commit fd4c1f8

File tree

6 files changed

+296
-1
lines changed

6 files changed

+296
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "DecisionFocusedLearningBenchmarks"
22
uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
33
authors = ["Members of JuliaDecisionFocusedLearning"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
7+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
78
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
89
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
910
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -14,6 +15,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
1415
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
1516
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
1617
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
18+
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
1719
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1820
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
1921
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
@@ -28,6 +30,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2830
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2931

3032
[compat]
33+
Colors = "0.13.1"
3134
ConstrainedShortestPaths = "0.6.0"
3235
DataDeps = "0.7"
3336
Distributions = "0.25"
@@ -38,6 +41,7 @@ HiGHS = "1.9"
3841
Images = "0.26.1"
3942
Ipopt = "1.6"
4043
JuMP = "1.22"
44+
LaTeXStrings = "1.4.0"
4145
LinearAlgebra = "1"
4246
Metalhead = "0.9.4"
4347
NPZ = "0.4"

docs/src/api/argmax_2d.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Argmax2D
2+
3+
## Public
4+
5+
```@autodocs
6+
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
7+
Private = false
8+
```
9+
10+
## Private
11+
12+
```@autodocs
13+
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
14+
Public = false
15+
```

src/Argmax2D/Argmax2D.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module Argmax2D
2+
3+
using ..Utils
4+
using Colors: Colors
5+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
6+
using Flux: Chain, Dense
7+
using LaTeXStrings: @L_str
8+
using LinearAlgebra: dot, norm
9+
using Plots: Plots
10+
using Random: Random, MersenneTwister
11+
12+
include("polytope.jl")
13+
14+
"""
15+
$TYPEDEF
16+
17+
Argmax becnhmark on a 2d polytope.
18+
19+
# Fields
20+
$TYPEDFIELDS
21+
"""
22+
struct Argmax2DBenchmark{E,R} <: AbstractBenchmark
23+
"number of features"
24+
nb_features::Int
25+
"true mapping between features and costs"
26+
encoder::E
27+
""
28+
polytope_vertex_range::R
29+
end
30+
31+
function Base.show(io::IO, bench::Argmax2DBenchmark)
32+
(; nb_features) = bench
33+
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
34+
end
35+
36+
"""
37+
$TYPEDSIGNATURES
38+
39+
Custom constructor for [`Argmax2DBenchmark`](@ref).
40+
"""
41+
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
42+
Random.seed!(seed)
43+
model = Chain(Dense(nb_features => 2; bias=false), vec)
44+
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
45+
end
46+
47+
maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]
48+
49+
"""
50+
$TYPEDSIGNATURES
51+
52+
Generate a dataset for the [`Argmax2DBenchmark`](@ref).
53+
"""
54+
function Utils.generate_dataset(
55+
bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
56+
)
57+
(; nb_features, encoder, polytope_vertex_range) = bench
58+
return map(1:dataset_size) do _
59+
x = randn(rng, nb_features)
60+
θ_true = encoder(x)
61+
θ_true ./= 2 * norm(θ_true)
62+
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
63+
y_true = maximizer(θ_true; instance)
64+
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
65+
end
66+
end
67+
68+
"""
69+
$TYPEDSIGNATURES
70+
71+
Maximizer for the [`Argmax2DBenchmark`](@ref).
72+
"""
73+
function Utils.generate_maximizer(::Argmax2DBenchmark)
74+
return maximizer
75+
end
76+
77+
"""
78+
$TYPEDSIGNATURES
79+
80+
Generate a statistical model for the [`Argmax2DBenchmark`](@ref).
81+
"""
82+
function Utils.generate_statistical_model(
83+
bench::Argmax2DBenchmark; seed=nothing, rng=MersenneTwister(seed)
84+
)
85+
Random.seed!(rng, seed)
86+
(; nb_features) = bench
87+
model = Chain(Dense(nb_features => 2; bias=false), vec)
88+
return model
89+
end
90+
91+
"""
92+
$TYPEDSIGNATURES
93+
94+
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
95+
"""
96+
function Utils.plot_data(
97+
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
98+
)
99+
(; instance) = sample
100+
pl = init_plot()
101+
plot_polytope!(pl, instance)
102+
plot_objective!(pl, θ_true)
103+
return plot_maximizer!(pl, θ_true, instance, maximizer)
104+
end
105+
106+
export Argmax2DBenchmark
107+
108+
end

src/Argmax2D/polytope.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
function build_polytope(N; shift=0.0)
2+
return [[cospi(2k / N + shift), sinpi(2k / N + shift)] for k in 0:(N - 1)]
3+
end
4+
5+
function init_plot(title="")
6+
pl = Plots.plot(;
7+
aspect_ratio=:equal,
8+
legend=:outerleft,
9+
xlim=(-1.1, 1.1),
10+
ylim=(-1.1, 1.1),
11+
title=title,
12+
)
13+
return pl
14+
end;
15+
16+
function plot_polytope!(pl, vertices)
17+
return Plots.plot!(
18+
vcat(map(first, vertices), first(vertices[1])),
19+
vcat(map(last, vertices), last(vertices[1]));
20+
fillrange=0,
21+
fillcolor=:gray,
22+
fillalpha=0.2,
23+
linecolor=:black,
24+
label=L"\mathrm{conv}(\mathcal{V})",
25+
)
26+
end;
27+
28+
const logocolors = Colors.JULIA_LOGO_COLORS
29+
30+
function plot_objective!(pl, θ)
31+
Plots.plot!(
32+
pl,
33+
[0.0, θ[1]],
34+
[0.0, θ[2]];
35+
color=logocolors.purple,
36+
arrow=true,
37+
lw=2,
38+
label=nothing,
39+
)
40+
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
41+
return pl
42+
end;
43+
44+
function plot_maximizer!(pl, θ, instance, maximizer)
45+
= maximizer(θ; instance)
46+
return Plots.scatter!(
47+
pl,
48+
[ŷ[1]],
49+
[ŷ[2]];
50+
color=logocolors.red,
51+
markersize=9,
52+
markershape=:square,
53+
label=L"f(\theta)",
54+
)
55+
end;
56+
57+
# function get_angle(v)
58+
# @assert !(norm(v) ≈ 0)
59+
# v = v ./ norm(v)
60+
# if v[2] >= 0
61+
# return acos(v[1])
62+
# else
63+
# return π + acos(-v[1])
64+
# end
65+
# end;
66+
67+
# function plot_distribution!(pl, probadist)
68+
# A = probadist.atoms
69+
# As = sort(A; by=get_angle)
70+
# p = probadist.weights
71+
# Plots.plot!(
72+
# pl,
73+
# vcat(map(first, As), first(As[1])),
74+
# vcat(map(last, As), last(As[1]));
75+
# fillrange=0,
76+
# fillcolor=:blue,
77+
# fillalpha=0.1,
78+
# linestyle=:dash,
79+
# linecolor=logocolors.blue,
80+
# label=L"\mathrm{conv}(\hat{p}(\theta))",
81+
# )
82+
# return Plots.scatter!(
83+
# pl,
84+
# map(first, A),
85+
# map(last, A);
86+
# markersize=25 .* p .^ 0.5,
87+
# markercolor=logocolors.blue,
88+
# markerstrokewidth=0,
89+
# markeralpha=0.4,
90+
# label=L"\hat{p}(\theta)",
91+
# )
92+
# end;
93+
94+
# function plot_expectation!(pl, probadist)
95+
# ŷΩ = compute_expectation(probadist)
96+
# return scatter!(
97+
# pl,
98+
# [ŷΩ[1]],
99+
# [ŷΩ[2]];
100+
# color=logocolors.blue,
101+
# markersize=6,
102+
# markershape=:hexagon,
103+
# label=L"\hat{f}(\theta)",
104+
# )
105+
# end;
106+
107+
# function compress_distribution!(
108+
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
109+
# ) where {A,W}
110+
# (; atoms, weights) = probadist
111+
# to_delete = Int[]
112+
# for i in length(probadist):-1:1
113+
# ai = atoms[i]
114+
# for j in 1:(i - 1)
115+
# aj = atoms[j]
116+
# if isapprox(ai, aj; atol=atol)
117+
# weights[j] += weights[i]
118+
# push!(to_delete, i)
119+
# break
120+
# end
121+
# end
122+
# end
123+
# sort!(to_delete)
124+
# deleteat!(atoms, to_delete)
125+
# deleteat!(weights, to_delete)
126+
# return probadist
127+
# end;

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ end
2424
include("Utils/Utils.jl")
2525

2626
include("Argmax/Argmax.jl")
27+
include("Argmax2D/Argmax2D.jl")
2728
include("Ranking/Ranking.jl")
2829
include("SubsetSelection/SubsetSelection.jl")
2930
include("Warcraft/Warcraft.jl")
@@ -33,6 +34,7 @@ include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
3334

3435
using .Utils
3536
using .Argmax
37+
using .Argmax2D
3638
using .Ranking
3739
using .SubsetSelection
3840
using .Warcraft
@@ -51,6 +53,7 @@ export compute_gap
5153

5254
# Export all benchmarks
5355
export ArgmaxBenchmark
56+
export Argmax2DBenchmark
5457
export RankingBenchmark
5558
export SubsetSelectionBenchmark
5659
export WarcraftBenchmark

test/argmax_2d.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
@testitem "Argmax2D" begin
2+
using DecisionFocusedLearningBenchmarks
3+
using Plots
4+
5+
nb_features = 5
6+
b = Argmax2DBenchmark(; nb_features=nb_features)
7+
8+
io = IOBuffer()
9+
show(io, b)
10+
@test String(take!(io)) == "Argmax2DBenchmark(nb_features=5)"
11+
12+
dataset = generate_dataset(b, 50)
13+
model = generate_statistical_model(b)
14+
maximizer = generate_maximizer(b)
15+
16+
# Test plot_data
17+
figure = plot_data(b, dataset[1])
18+
@test figure isa Plots.Plot
19+
20+
for (i, sample) in enumerate(dataset)
21+
(; x, θ_true, y_true, instance) = sample
22+
@test length(x) == nb_features
23+
@test length(θ_true) == 2
24+
@test length(y_true) == 2
25+
@test !isnothing(sample.instance)
26+
@test instance isa Vector{Vector{Float64}}
27+
@test all(length(vertex) == 2 for vertex in instance)
28+
@test y_true in instance
29+
@test y_true == maximizer(θ_true; instance=instance)
30+
31+
θ = model(x)
32+
@test length(θ) == 2
33+
34+
y = maximizer(θ; instance=instance)
35+
@test length(y) == 2
36+
@test y in instance
37+
end
38+
end

0 commit comments

Comments
 (0)