Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "DecisionFocusedLearningBenchmarks"
uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
authors = ["Members of JuliaDecisionFocusedLearning"]
version = "0.2.2"
version = "0.2.3"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -14,6 +15,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Expand All @@ -28,6 +30,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Colors = "0.13.1"
ConstrainedShortestPaths = "0.6.0"
DataDeps = "0.7"
Distributions = "0.25"
Expand All @@ -38,6 +41,7 @@ HiGHS = "1.9"
Images = "0.26.1"
Ipopt = "1.6"
JuMP = "1.22"
LaTeXStrings = "1.4.0"
LinearAlgebra = "1"
Metalhead = "0.9.4"
NPZ = "0.4"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/api/argmax_2d.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Argmax2D

## Public

```@autodocs
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
Private = false
```

## Private

```@autodocs
Modules = [DecisionFocusedLearningBenchmarks.Argmax2D]
Public = false
```
108 changes: 108 additions & 0 deletions src/Argmax2D/Argmax2D.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
module Argmax2D

using ..Utils
using Colors: Colors
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense
using LaTeXStrings: @L_str
using LinearAlgebra: dot, norm
using Plots: Plots
using Random: Random, MersenneTwister

include("polytope.jl")

"""
$TYPEDEF

Argmax becnhmark on a 2d polytope.

# Fields
$TYPEDFIELDS
"""
struct Argmax2DBenchmark{E,R} <: AbstractBenchmark
"number of features"
nb_features::Int
"true mapping between features and costs"
encoder::E
""
polytope_vertex_range::R
end

function Base.show(io::IO, bench::Argmax2DBenchmark)
(; nb_features) = bench
return print(io, "Argmax2DBenchmark(nb_features=$nb_features)")
end

"""
$TYPEDSIGNATURES

Custom constructor for [`Argmax2DBenchmark`](@ref).
"""
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
Random.seed!(seed)
model = Chain(Dense(nb_features => 2; bias=false), vec)
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
end

maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]

"""
$TYPEDSIGNATURES

Generate a dataset for the [`Argmax2DBenchmark`](@ref).
"""
function Utils.generate_dataset(
bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
)
(; nb_features, encoder, polytope_vertex_range) = bench
return map(1:dataset_size) do _
x = randn(rng, nb_features)
θ_true = encoder(x)
θ_true ./= 2 * norm(θ_true)
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
y_true = maximizer(θ_true; instance)
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
end
end

"""
$TYPEDSIGNATURES

Maximizer for the [`Argmax2DBenchmark`](@ref).
"""
function Utils.generate_maximizer(::Argmax2DBenchmark)
return maximizer
end

"""
$TYPEDSIGNATURES

Generate a statistical model for the [`Argmax2DBenchmark`](@ref).
"""
function Utils.generate_statistical_model(
bench::Argmax2DBenchmark; seed=nothing, rng=MersenneTwister(seed)
)
Random.seed!(rng, seed)
(; nb_features) = bench
model = Chain(Dense(nb_features => 2; bias=false), vec)
return model
end

"""
$TYPEDSIGNATURES

Plot the data sample for the [`Argmax2DBenchmark`](@ref).
"""
function Utils.plot_data(
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
)
(; instance) = sample
pl = init_plot()
plot_polytope!(pl, instance)
plot_objective!(pl, θ_true)
return plot_maximizer!(pl, θ_true, instance, maximizer)
end

export Argmax2DBenchmark

end
127 changes: 127 additions & 0 deletions src/Argmax2D/polytope.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
function build_polytope(N; shift=0.0)
return [[cospi(2k / N + shift), sinpi(2k / N + shift)] for k in 0:(N - 1)]
end

function init_plot(title="")
pl = Plots.plot(;
aspect_ratio=:equal,
legend=:outerleft,
xlim=(-1.1, 1.1),
ylim=(-1.1, 1.1),
title=title,
)
return pl
end;

function plot_polytope!(pl, vertices)
return Plots.plot!(
vcat(map(first, vertices), first(vertices[1])),
vcat(map(last, vertices), last(vertices[1]));
fillrange=0,
fillcolor=:gray,
fillalpha=0.2,
linecolor=:black,
label=L"\mathrm{conv}(\mathcal{V})",
)
end;

const logocolors = Colors.JULIA_LOGO_COLORS

function plot_objective!(pl, θ)
Plots.plot!(
pl,
[0.0, θ[1]],
[0.0, θ[2]];
color=logocolors.purple,
arrow=true,
lw=2,
label=nothing,
)
Plots.annotate!(pl, [-0.2 * θ[1]], [-0.2 * θ[2]], [L"\theta"])
return pl
end;

function plot_maximizer!(pl, θ, instance, maximizer)
ŷ = maximizer(θ; instance)
return Plots.scatter!(
pl,
[ŷ[1]],
[ŷ[2]];
color=logocolors.red,
markersize=9,
markershape=:square,
label=L"f(\theta)",
)
end;

# function get_angle(v)
# @assert !(norm(v) ≈ 0)
# v = v ./ norm(v)
# if v[2] >= 0
# return acos(v[1])
# else
# return π + acos(-v[1])
# end
# end;

# function plot_distribution!(pl, probadist)
# A = probadist.atoms
# As = sort(A; by=get_angle)
# p = probadist.weights
# Plots.plot!(
# pl,
# vcat(map(first, As), first(As[1])),
# vcat(map(last, As), last(As[1]));
# fillrange=0,
# fillcolor=:blue,
# fillalpha=0.1,
# linestyle=:dash,
# linecolor=logocolors.blue,
# label=L"\mathrm{conv}(\hat{p}(\theta))",
# )
# return Plots.scatter!(
# pl,
# map(first, A),
# map(last, A);
# markersize=25 .* p .^ 0.5,
# markercolor=logocolors.blue,
# markerstrokewidth=0,
# markeralpha=0.4,
# label=L"\hat{p}(\theta)",
# )
# end;

# function plot_expectation!(pl, probadist)
# ŷΩ = compute_expectation(probadist)
# return scatter!(
# pl,
# [ŷΩ[1]],
# [ŷΩ[2]];
# color=logocolors.blue,
# markersize=6,
# markershape=:hexagon,
# label=L"\hat{f}(\theta)",
# )
# end;

# function compress_distribution!(
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
# ) where {A,W}
# (; atoms, weights) = probadist
# to_delete = Int[]
# for i in length(probadist):-1:1
# ai = atoms[i]
# for j in 1:(i - 1)
# aj = atoms[j]
# if isapprox(ai, aj; atol=atol)
# weights[j] += weights[i]
# push!(to_delete, i)
# break
# end
# end
# end
# sort!(to_delete)
# deleteat!(atoms, to_delete)
# deleteat!(weights, to_delete)
# return probadist
# end;
3 changes: 3 additions & 0 deletions src/DecisionFocusedLearningBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ end
include("Utils/Utils.jl")

include("Argmax/Argmax.jl")
include("Argmax2D/Argmax2D.jl")
include("Ranking/Ranking.jl")
include("SubsetSelection/SubsetSelection.jl")
include("Warcraft/Warcraft.jl")
Expand All @@ -33,6 +34,7 @@ include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")

using .Utils
using .Argmax
using .Argmax2D
using .Ranking
using .SubsetSelection
using .Warcraft
Expand All @@ -51,6 +53,7 @@ export compute_gap

# Export all benchmarks
export ArgmaxBenchmark
export Argmax2DBenchmark
export RankingBenchmark
export SubsetSelectionBenchmark
export WarcraftBenchmark
Expand Down
38 changes: 38 additions & 0 deletions test/argmax_2d.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
@testitem "Argmax2D" begin
using DecisionFocusedLearningBenchmarks
using Plots

nb_features = 5
b = Argmax2DBenchmark(; nb_features=nb_features)

io = IOBuffer()
show(io, b)
@test String(take!(io)) == "Argmax2DBenchmark(nb_features=5)"

dataset = generate_dataset(b, 50)
model = generate_statistical_model(b)
maximizer = generate_maximizer(b)

# Test plot_data
figure = plot_data(b, dataset[1])
@test figure isa Plots.Plot

for (i, sample) in enumerate(dataset)
(; x, θ_true, y_true, instance) = sample
@test length(x) == nb_features
@test length(θ_true) == 2
@test length(y_true) == 2
@test !isnothing(sample.instance)
@test instance isa Vector{Vector{Float64}}
@test all(length(vertex) == 2 for vertex in instance)
@test y_true in instance
@test y_true == maximizer(θ_true; instance=instance)

θ = model(x)
@test length(θ) == 2

y = maximizer(θ; instance=instance)
@test length(y) == 2
@test y in instance
end
end