Skip to content

Commit 0a5f643

Browse files
committed
Merge branch 'main' into dynamic-interface-poc
2 parents 5f29047 + a98a501 commit 0a5f643

File tree

9 files changed

+244
-2
lines changed

9 files changed

+244
-2
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "DecisionFocusedLearningBenchmarks"
22
uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
33
authors = ["Members of JuliaDecisionFocusedLearning"]
4-
version = "0.2.2"
4+
version = "0.2.4"
55

66
[deps]
77
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
88
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
9+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
910
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
1011
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
1112
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -19,6 +20,7 @@ Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
1920
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
2021
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
2122
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
23+
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
2224
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2325
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
2426
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
@@ -35,6 +37,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3537
[compat]
3638
Combinatorics = "1.0.3"
3739
CommonRLInterface = "0.3.3"
40+
Colors = "0.13.1"
3841
ConstrainedShortestPaths = "0.6.0"
3942
DataDeps = "0.7"
4043
Distributions = "0.25"
@@ -48,6 +51,7 @@ Ipopt = "1.6"
4851
IterTools = "1.10.0"
4952
JSON = "0.21.4"
5053
JuMP = "1.22"
54+
LaTeXStrings = "1.4.0"
5155
LinearAlgebra = "1"
5256
Metalhead = "0.9.4"
5357
NPZ = "0.4"

docs/src/api/argmax_2d.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Argmax2D
2+
3+
## Public
4+
5+
```@autodocs
6+
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
7+
Private = false
8+
```
9+
10+
## Private
11+
12+
```@autodocs
13+
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
14+
Public = false
15+
```

src/Argmax/Argmax.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothin
4040
return ArgmaxBenchmark(instance_dim, nb_features, model)
4141
end
4242

43+
function Utils.is_minimization_problem(::ArgmaxBenchmark)
44+
return false
45+
end
46+
4347
"""
4448
$TYPEDSIGNATURES
4549

src/Argmax2D/Argmax2D.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
module Argmax2D
2+
3+
using ..Utils
4+
using Colors: Colors
5+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
6+
using Flux: Chain, Dense
7+
using LaTeXStrings: @L_str
8+
using LinearAlgebra: dot, norm
9+
using Plots: Plots
10+
using Random: Random, MersenneTwister
11+
12+
include("polytope.jl")
13+
14+
"""
15+
$TYPEDEF
16+
17+
Argmax becnhmark on a 2d polytope.
18+
19+
# Fields
20+
$TYPEDFIELDS
21+
"""
22+
struct Argmax2DBenchmark{E,R} <: AbstractBenchmark
23+
"number of features"
24+
nb_features::Int
25+
"true mapping between features and costs"
26+
encoder::E
27+
""
28+
polytope_vertex_range::R
29+
end
30+
31+
function Base.show(io::IO, bench::Argmax2DBenchmark)
32+
(; nb_features) = bench
33+
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
34+
end
35+
36+
"""
37+
$TYPEDSIGNATURES
38+
39+
Custom constructor for [`Argmax2DBenchmark`](@ref).
40+
"""
41+
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
42+
Random.seed!(seed)
43+
model = Dense(nb_features => 2; bias=false)
44+
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
45+
end
46+
47+
function Utils.is_minimization_problem(::Argmax2DBenchmark)
48+
return false
49+
end
50+
51+
maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instance)]
52+
53+
"""
54+
$TYPEDSIGNATURES
55+
56+
Generate a dataset for the [`Argmax2DBenchmark`](@ref).
57+
"""
58+
function Utils.generate_dataset(
59+
bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
60+
)
61+
(; nb_features, encoder, polytope_vertex_range) = bench
62+
return map(1:dataset_size) do _
63+
x = randn(rng, Float32, nb_features)
64+
θ_true = encoder(x)
65+
θ_true ./= 2 * norm(θ_true)
66+
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
67+
y_true = maximizer(θ_true; instance)
68+
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
69+
end
70+
end
71+
72+
"""
73+
$TYPEDSIGNATURES
74+
75+
Maximizer for the [`Argmax2DBenchmark`](@ref).
76+
"""
77+
function Utils.generate_maximizer(::Argmax2DBenchmark)
78+
return maximizer
79+
end
80+
81+
"""
82+
$TYPEDSIGNATURES
83+
84+
Generate a statistical model for the [`Argmax2DBenchmark`](@ref).
85+
"""
86+
function Utils.generate_statistical_model(
87+
bench::Argmax2DBenchmark; seed=nothing, rng=MersenneTwister(seed)
88+
)
89+
Random.seed!(rng, seed)
90+
(; nb_features) = bench
91+
model = Dense(nb_features => 2; bias=false)
92+
return model
93+
end
94+
95+
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
96+
pl = init_plot()
97+
plot_polytope!(pl, instance)
98+
plot_objective!(pl, θ)
99+
return plot_maximizer!(pl, θ, instance, maximizer)
100+
end
101+
102+
"""
103+
$TYPEDSIGNATURES
104+
105+
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
106+
"""
107+
function Utils.plot_data(
108+
bench::Argmax2DBenchmark,
109+
sample::DataSample;
110+
instance=sample.instance,
111+
θ=sample.θ_true,
112+
kwargs...,
113+
)
114+
return Utils.plot_data(bench; instance, θ, kwargs...)
115+
end
116+
117+
export Argmax2DBenchmark
118+
119+
end

src/Argmax2D/polytope.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
function build_polytope(N; shift=0.0)
2+
return [[cospi(2k / N + shift), sinpi(2k / N + shift)] for k in 0:(N - 1)]
3+
end
4+
5+
function init_plot(title="")
6+
pl = Plots.plot(;
7+
aspect_ratio=:equal,
8+
legend=:outerleft,
9+
xlim=(-1.1, 1.1),
10+
ylim=(-1.1, 1.1),
11+
title=title,
12+
)
13+
return pl
14+
end;
15+
16+
function plot_polytope!(pl, vertices)
17+
return Plots.plot!(
18+
vcat(map(first, vertices), first(vertices[1])),
19+
vcat(map(last, vertices), last(vertices[1]));
20+
fillrange=0,
21+
fillcolor=:gray,
22+
fillalpha=0.2,
23+
linecolor=:black,
24+
label=L"\mathrm{conv}(\mathcal{Y}(x))",
25+
)
26+
end;
27+
28+
function plot_objective!(pl, θ)
29+
Plots.plot!(
30+
pl,
31+
[0.0, θ[1]],
32+
[0.0, θ[2]];
33+
color=Colors.JULIA_LOGO_COLORS.purple,
34+
arrow=true,
35+
lw=2,
36+
label=nothing,
37+
)
38+
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
39+
return pl
40+
end;
41+
42+
function plot_maximizer!(pl, θ, instance, maximizer)
43+
= maximizer(θ; instance)
44+
return Plots.scatter!(
45+
pl,
46+
[ŷ[1]],
47+
[ŷ[2]];
48+
color=Colors.JULIA_LOGO_COLORS.red,
49+
markersize=9,
50+
markershape=:square,
51+
label=L"f(\theta)",
52+
)
53+
end;

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
include("Utils/Utils.jl")
4949

5050
include("Argmax/Argmax.jl")
51+
include("Argmax2D/Argmax2D.jl")
5152
include("Ranking/Ranking.jl")
5253
include("SubsetSelection/SubsetSelection.jl")
5354
include("Warcraft/Warcraft.jl")
@@ -59,6 +60,7 @@ include("DynamicAssortment/DynamicAssortment.jl")
5960

6061
using .Utils
6162
using .Argmax
63+
using .Argmax2D
6264
using .Ranking
6365
using .SubsetSelection
6466
using .Warcraft
@@ -83,6 +85,7 @@ export compute_gap
8385

8486
# Export all benchmarks
8587
export ArgmaxBenchmark
88+
export Argmax2DBenchmark
8689
export RankingBenchmark
8790
export SubsetSelectionBenchmark
8891
export WarcraftBenchmark

src/gurobi_setup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using DocStringExtensions: TYPEDSIGNATURES
22
using JuMP: Model
33

4-
@info "Creating a GRB_ENV const for AircraftRoutingBase..."
4+
@info "Creating a GRB_ENV const for DecisionFocusedLearningBenchmarks..."
55
# Gurobi package setup (see https://github.com/jump-dev/Gurobi.jl/issues/424)
66
const GRB_ENV = Ref{Gurobi.Env}()
77
GRB_ENV[] = Gurobi.Env()

test/argmax.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
model = generate_statistical_model(b)
1515
maximizer = generate_maximizer(b)
1616

17+
gap = compute_gap(b, dataset, model, maximizer)
18+
@test gap >= 0
19+
1720
for (i, sample) in enumerate(dataset)
1821
(; x, θ_true, y_true) = sample
1922
@test size(x) == (nb_features, instance_dim)

test/argmax_2d.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
@testitem "Argmax2D" begin
2+
using DecisionFocusedLearningBenchmarks
3+
using Plots
4+
5+
nb_features = 5
6+
b = Argmax2DBenchmark(; nb_features=nb_features)
7+
8+
io = IOBuffer()
9+
show(io, b)
10+
@test String(take!(io)) == "Argmax2DBenchmark(nb_features=5)"
11+
12+
dataset = generate_dataset(b, 50)
13+
model = generate_statistical_model(b)
14+
maximizer = generate_maximizer(b)
15+
16+
gap = compute_gap(b, dataset, model, maximizer)
17+
@test gap >= 0
18+
19+
# Test plot_data
20+
figure = plot_data(b, dataset[1])
21+
@test figure isa Plots.Plot
22+
23+
for (i, sample) in enumerate(dataset)
24+
(; x, θ_true, y_true, instance) = sample
25+
@test length(x) == nb_features
26+
@test length(θ_true) == 2
27+
@test length(y_true) == 2
28+
@test !isnothing(sample.instance)
29+
@test instance isa Vector{Vector{Float64}}
30+
@test all(length(vertex) == 2 for vertex in instance)
31+
@test y_true in instance
32+
@test y_true == maximizer(θ_true; instance=instance)
33+
34+
θ = model(x)
35+
@test length(θ) == 2
36+
37+
y = maximizer(θ; instance=instance)
38+
@test length(y) == 2
39+
@test y in instance
40+
end
41+
end

0 commit comments

Comments
 (0)