Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DecisionFocusedLearningBenchmarks"
uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
authors = ["Members of JuliaDecisionFocusedLearning"]
version = "0.2.3"
version = "0.2.4"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down
4 changes: 4 additions & 0 deletions src/Argmax/Argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5, seed=nothin
return ArgmaxBenchmark(instance_dim, nb_features, model)
end

function Utils.is_minimization_problem(::ArgmaxBenchmark)
return false
end

"""
$TYPEDSIGNATURES

Expand Down
31 changes: 21 additions & 10 deletions src/Argmax2D/Argmax2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ Custom constructor for [`Argmax2DBenchmark`](@ref).
"""
function Argmax2DBenchmark(; nb_features::Int=5, seed=nothing, polytope_vertex_range=[6])
Random.seed!(seed)
model = Chain(Dense(nb_features => 2; bias=false), vec)
model = Dense(nb_features => 2; bias=false)
return Argmax2DBenchmark(nb_features, model, polytope_vertex_range)
end

maximizer(θ; instance) = instance[argmax(dot(θ, v) for v in instance)]
function Utils.is_minimization_problem(::Argmax2DBenchmark)
return false
end

maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instance)]

"""
$TYPEDSIGNATURES
Expand All @@ -56,7 +60,7 @@ function Utils.generate_dataset(
)
(; nb_features, encoder, polytope_vertex_range) = bench
return map(1:dataset_size) do _
x = randn(rng, nb_features)
x = randn(rng, Float32, nb_features)
θ_true = encoder(x)
θ_true ./= 2 * norm(θ_true)
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
Expand Down Expand Up @@ -84,23 +88,30 @@ function Utils.generate_statistical_model(
)
Random.seed!(rng, seed)
(; nb_features) = bench
model = Chain(Dense(nb_features => 2; bias=false), vec)
model = Dense(nb_features => 2; bias=false)
return model
end

function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
pl = init_plot()
plot_polytope!(pl, instance)
plot_objective!(pl, θ)
return plot_maximizer!(pl, θ, instance, maximizer)
end

"""
$TYPEDSIGNATURES

Plot the data sample for the [`Argmax2DBenchmark`](@ref).
"""
function Utils.plot_data(
::Argmax2DBenchmark, sample::DataSample; θ_true=sample.θ_true, kwargs...
bench::Argmax2DBenchmark,
sample::DataSample;
instance=sample.instance,
θ=sample.θ_true,
kwargs...,
)
(; instance) = sample
pl = init_plot()
plot_polytope!(pl, instance)
plot_objective!(pl, θ_true)
return plot_maximizer!(pl, θ_true, instance, maximizer)
return Utils.plot_data(bench; instance, θ, kwargs...)
end

export Argmax2DBenchmark
Expand Down
80 changes: 3 additions & 77 deletions src/Argmax2D/polytope.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@ function plot_polytope!(pl, vertices)
fillcolor=:gray,
fillalpha=0.2,
linecolor=:black,
label=L"\mathrm{conv}(\mathcal{V})",
label=L"\mathrm{conv}(\mathcal{Y}(x))",
)
end;

const logocolors = Colors.JULIA_LOGO_COLORS

function plot_objective!(pl, θ)
Plots.plot!(
pl,
[0.0, θ[1]],
[0.0, θ[2]];
color=logocolors.purple,
color=Colors.JULIA_LOGO_COLORS.purple,
arrow=true,
lw=2,
label=nothing,
Expand All @@ -47,81 +45,9 @@ function plot_maximizer!(pl, θ, instance, maximizer)
pl,
[ŷ[1]],
[ŷ[2]];
color=logocolors.red,
color=Colors.JULIA_LOGO_COLORS.red,
markersize=9,
markershape=:square,
label=L"f(\theta)",
)
end;

# function get_angle(v)
# @assert !(norm(v) ≈ 0)
# v = v ./ norm(v)
# if v[2] >= 0
# return acos(v[1])
# else
# return π + acos(-v[1])
# end
# end;

# function plot_distribution!(pl, probadist)
# A = probadist.atoms
# As = sort(A; by=get_angle)
# p = probadist.weights
# Plots.plot!(
# pl,
# vcat(map(first, As), first(As[1])),
# vcat(map(last, As), last(As[1]));
# fillrange=0,
# fillcolor=:blue,
# fillalpha=0.1,
# linestyle=:dash,
# linecolor=logocolors.blue,
# label=L"\mathrm{conv}(\hat{p}(\theta))",
# )
# return Plots.scatter!(
# pl,
# map(first, A),
# map(last, A);
# markersize=25 .* p .^ 0.5,
# markercolor=logocolors.blue,
# markerstrokewidth=0,
# markeralpha=0.4,
# label=L"\hat{p}(\theta)",
# )
# end;

# function plot_expectation!(pl, probadist)
# ŷΩ = compute_expectation(probadist)
# return scatter!(
# pl,
# [ŷΩ[1]],
# [ŷΩ[2]];
# color=logocolors.blue,
# markersize=6,
# markershape=:hexagon,
# label=L"\hat{f}(\theta)",
# )
# end;

# function compress_distribution!(
# probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
# ) where {A,W}
# (; atoms, weights) = probadist
# to_delete = Int[]
# for i in length(probadist):-1:1
# ai = atoms[i]
# for j in 1:(i - 1)
# aj = atoms[j]
# if isapprox(ai, aj; atol=atol)
# weights[j] += weights[i]
# push!(to_delete, i)
# break
# end
# end
# end
# sort!(to_delete)
# deleteat!(atoms, to_delete)
# deleteat!(weights, to_delete)
# return probadist
# end;
3 changes: 3 additions & 0 deletions test/argmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
model = generate_statistical_model(b)
maximizer = generate_maximizer(b)

gap = compute_gap(b, dataset, model, maximizer)
@test gap >= 0

for (i, sample) in enumerate(dataset)
(; x, θ_true, y_true) = sample
@test size(x) == (nb_features, instance_dim)
Expand Down
3 changes: 3 additions & 0 deletions test/argmax_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
model = generate_statistical_model(b)
maximizer = generate_maximizer(b)

gap = compute_gap(b, dataset, model, maximizer)
@test gap >= 0

# Test plot_data
figure = plot_data(b, dataset[1])
@test figure isa Plots.Plot
Expand Down