Skip to content

Commit 335971c

Browse files
committed
Implement argmax
1 parent de6213c commit 335971c

File tree

4 files changed

+125
-3
lines changed

4 files changed

+125
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
/docs/src/index.md
33
data
44
scripts
5+
.DS_Store
56

67
# Files generated by invoking Julia with --code-coverage
78
*.jl.cov

src/Argmax/Argmax.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module Argmax
2+
3+
using ..Utils
4+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
5+
using Flux: Chain, Dense
6+
using Random
7+
8+
"""
9+
$TYPEDEF
10+
11+
Benchmark problem with an argmax as the CO algorithm.
12+
13+
# Fields
14+
$TYPEDFIELDS
15+
"""
16+
struct ArgmaxBenchmark <: AbstractBenchmark
17+
"iinstances dimension, total number of classes"
18+
instance_dim::Int
19+
"number of features"
20+
nb_features::Int
21+
end
22+
23+
function Base.show(io::IO, bench::ArgmaxBenchmark)
24+
(; instance_dim, nb_features) = bench
25+
return print(
26+
io, "ArgmaxBenchmark(instance_dim=$instance_dim, nb_features=$nb_features)"
27+
)
28+
end
29+
30+
function ArgmaxBenchmark(; instance_dim::Int=10, nb_features::Int=5)
31+
return ArgmaxBenchmark(instance_dim, nb_features)
32+
end
33+
34+
"""
35+
$TYPEDSIGNATURES
36+
37+
One-hot encoding of the argmax function.
38+
"""
39+
function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real}
40+
e = zeros(R, length(z))
41+
e[argmax(z)] = one(R)
42+
return e
43+
end
44+
45+
"""
46+
$TYPEDSIGNATURES
47+
48+
Return a top k maximizer.
49+
"""
50+
function Utils.generate_maximizer(bench::ArgmaxBenchmark)
51+
return one_hot_argmax
52+
end
53+
54+
"""
55+
$TYPEDSIGNATURES
56+
57+
Generate a dataset of labeled instances for the subset selection problem.
58+
The mapping between features and cost is identity.
59+
"""
60+
function Utils.generate_dataset(bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0)
61+
(; instance_dim, nb_features) = bench
62+
rng = MersenneTwister(seed)
63+
features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
64+
mapping = Chain(Dense(nb_features => 1; bias=false), vec)
65+
costs = mapping.(features)
66+
solutions = one_hot_argmax.(costs)
67+
return [
68+
DataSample(; x, θ_true, y_true) for
69+
(x, θ_true, y_true) in zip(features, costs, solutions)
70+
]
71+
end
72+
73+
"""
74+
$TYPEDSIGNATURES
75+
76+
Initialize a linear model for `bench` using `Flux`.
77+
"""
78+
function Utils.generate_statistical_model(bench::ArgmaxBenchmark; seed=0)
79+
Random.seed!(seed)
80+
(; nb_features) = bench
81+
return Chain(Dense(nb_features => 1; bias=false), vec)
82+
end
83+
84+
export ArgmaxBenchmark
85+
86+
end

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,19 @@ end
1919

2020
include("Utils/Utils.jl")
2121

22+
include("Argmax/Argmax.jl")
23+
include("Ranking/Ranking.jl")
24+
include("SubsetSelection/SubsetSelection.jl")
2225
include("Warcraft/Warcraft.jl")
2326
include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
2427
include("PortfolioOptimization/PortfolioOptimization.jl")
25-
include("SubsetSelection/SubsetSelection.jl")
2628

2729
using .Utils
30+
using .Argmax
31+
using .SubsetSelection
2832
using .Warcraft
2933
using .FixedSizeShortestPath
3034
using .PortfolioOptimization
31-
using .SubsetSelection
3235

3336
# Interface
3437
export AbstractBenchmark, DataSample
@@ -39,9 +42,10 @@ export plot_data
3942
export compute_gap
4043

4144
# Export all benchmarks
45+
export ArgmaxBenchmark
46+
export SubsetSelectionBenchmark
4247
export WarcraftBenchmark
4348
export FixedSizeShortestPathBenchmark
4449
export PortfolioOptimizationBenchmark
45-
export SubsetSelectionBenchmark
4650

4751
end # module DecisionFocusedLearningBenchmarks

test/argmax.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
@testitem "Argmax" begin
2+
using DecisionFocusedLearningBenchmarks
3+
4+
instance_dim = 10
5+
nb_features = 5
6+
7+
b = ArgmaxBenchmark(; instance_dim=instance_dim, nb_features=nb_features)
8+
9+
io = IOBuffer()
10+
show(io, b)
11+
@test String(take!(io)) == "ArgmaxBenchmark(instance_dim=10, nb_features=5)"
12+
13+
dataset = generate_dataset(b, 50)
14+
model = generate_statistical_model(b)
15+
maximizer = generate_maximizer(b)
16+
17+
for (i, sample) in enumerate(dataset)
18+
(; x, θ_true, y_true) = sample
19+
@test size(x) == (nb_features, instance_dim)
20+
@test length(θ_true) == instance_dim
21+
@test length(y_true) == instance_dim
22+
@test isnothing(sample.instance)
23+
@test all(y_true .== maximizer(θ_true))
24+
25+
θ = model(x)
26+
@test length(θ) == instance_dim
27+
28+
y = maximizer(θ)
29+
@test length(y) == instance_dim
30+
end
31+
end

0 commit comments

Comments
 (0)