Commit 5a2e852

Implement generate_sample interface
1 parent 2f406e0 commit 5a2e852

13 files changed: 144 additions & 128 deletions

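This commit replaces the per-benchmark `generate_dataset` methods with a per-sample interface: each benchmark now implements `Utils.generate_sample(bench, rng::AbstractRNG; kwargs...)`, which returns a single `DataSample`. Presumably a generic `generate_dataset` can then be written once on top of it; the sketch below is hypothetical (it is not part of this commit, and it assumes `generate_sample` ends up exported alongside `generate_dataset`):

    # Hypothetical generic fallback built on the new per-sample interface.
    using DecisionFocusedLearningBenchmarks
    using Random: AbstractRNG, MersenneTwister

    function my_generate_dataset(bench, dataset_size::Int; seed::Int=0, kwargs...)
        rng = MersenneTwister(seed)
        # One RNG for all samples keeps the whole dataset reproducible from a single seed.
        return [generate_sample(bench, rng; kwargs...) for _ in 1:dataset_size]
    end
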
Project.toml

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,8 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
 Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
 Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
+IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
+JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
@@ -37,6 +39,8 @@ Graphs = "1.11"
 HiGHS = "1.9"
 Images = "0.26.1"
 Ipopt = "1.6"
+IterTools = "1.10.0"
+JSON = "0.21.4"
 JuMP = "1.22"
 LinearAlgebra = "1"
 Metalhead = "0.9.4"

src/Argmax/Argmax.jl

Lines changed: 28 additions & 14 deletions
@@ -62,25 +62,39 @@ end

 """
 $TYPEDSIGNATURES
-
-Generate a dataset of labeled instances for the argmax problem.
 """
-function Utils.generate_dataset(
-    bench::ArgmaxBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
+function Utils.generate_sample(
+    bench::ArgmaxBenchmark, rng::AbstractRNG; noise_std::Float32=0.0f0
 )
     (; instance_dim, nb_features, encoder) = bench
-    rng = MersenneTwister(seed)
-    features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
-    costs = encoder.(features)
-    noisy_solutions = [
-        one_hot_argmax(θ + noise_std * randn(rng, Float32, instance_dim)) for θ in costs
-    ]
-    return [
-        DataSample(; x, θ_true, y_true) for
-        (x, θ_true, y_true) in zip(features, costs, noisy_solutions)
-    ]
+    features = randn(rng, Float32, nb_features, instance_dim)
+    costs = encoder(features)
+    noisy_solution = one_hot_argmax(costs + noise_std * randn(rng, Float32, instance_dim))
+    return DataSample(; x=features, θ_true=costs, y_true=noisy_solution)
 end

+# """
+# $TYPEDSIGNATURES
+
+# Generate a dataset of labeled instances for the argmax problem.
+# """
+# function Utils.generate_dataset(
+#     bench::ArgmaxBenchmark, dataset_size::Int; noise_std=0.0, kwargs...
+# )
+#     return Utils.generate_dataset(bench, dataset_size; noise_std=noise_std, kwargs...)
+#     # (; instance_dim, nb_features, encoder) = bench
+#     # rng = MersenneTwister(seed)
+#     # features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
+#     # costs = encoder.(features)
+#     # noisy_solutions = [
+#     #     one_hot_argmax(θ + noise_std * randn(rng, Float32, instance_dim)) for θ in costs
+#     # ]
+#     # return [
+#     #     DataSample(; x, θ_true, y_true) for
+#     #     (x, θ_true, y_true) in zip(features, costs, noisy_solutions)
+#     # ]
+# end
+
 """
 $TYPEDSIGNATURES

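A usage sketch for the new method; the `ArgmaxBenchmark()` default constructor, the `DataSample` field access, and an exported `generate_sample` are assumptions inferred from this diff rather than shown in it:

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    bench = ArgmaxBenchmark()          # assumed default constructor
    rng = MersenneTwister(0)
    sample = generate_sample(bench, rng; noise_std=0.1f0)
    sample.x        # (nb_features, instance_dim) feature matrix
    sample.θ_true   # encoder output used as true costs
    sample.y_true   # one-hot argmax of the noise-perturbed costs
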
src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ include("Warcraft/Warcraft.jl")
 include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
 include("PortfolioOptimization/PortfolioOptimization.jl")
 include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
-include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
+# include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")

 using .Utils
 using .Argmax
@@ -64,10 +64,10 @@ using .Warcraft
 using .FixedSizeShortestPath
 using .PortfolioOptimization
 using .StochasticVehicleScheduling
-using .DynamicVehicleScheduling
+# using .DynamicVehicleScheduling

 # Interface
-export AbstractBenchmark, DataSample
+export AbstractBenchmark, AbstractStochasticBenchmark, AbstractDynamicBenchmark, DataSample
 export generate_dataset
 export generate_statistical_model
 export generate_maximizer, maximizer_kwargs

src/DynamicVehicleScheduling/DynamicVSP/algorithms/prize_collecting_vsp.jl

Lines changed: 8 additions & 8 deletions
@@ -205,11 +205,11 @@ function _objective_value(θ, routes; instance)
     return -total, g
 end

-function ChainRulesCore.rrule(::typeof(my_objective_value), θ, routes; instance)
-    total, g = _objective_value(θ, routes; instance)
-    function pullback(dy)
-        g = g .* dy
-        return NoTangent(), g, NoTangent()
-    end
-    return total, pullback
-end
+# function ChainRulesCore.rrule(::typeof(my_objective_value), θ, routes; instance)
+#     total, g = _objective_value(θ, routes; instance)
+#     function pullback(dy)
+#         g = g .* dy
+#         return NoTangent(), g, NoTangent()
+#     end
+#     return total, pullback
+# end

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ using Base: @kwdef
 using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
 using Graphs
 using HiGHS
-using InferOpt
+# using InferOpt
 using IterTools: partition
 using JSON
 using JuMP

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 10 additions & 31 deletions
@@ -103,45 +103,24 @@ function Utils.generate_maximizer(bench::FixedSizeShortestPathBenchmark; use_dij
     return shortest_path_maximizer
 end

-"""
-$TYPEDSIGNATURES
-
-Generate dataset for the shortest path problem.
-"""
-function Utils.generate_dataset(
-    bench::FixedSizeShortestPathBenchmark,
-    dataset_size::Int=10;
-    seed::Int=0,
-    type::Type=Float32,
+function Utils.generate_sample(
+    bench::FixedSizeShortestPathBenchmark, rng::AbstractRNG; type::Type=Float32
 )
-    # Set seed
-    rng = MersenneTwister(seed)
     (; graph, p, deg, ν) = bench
-
+    features = randn(rng, Float32, bench.p)
     E = Graphs.ne(graph)
-
-    # Features
-    features = [randn(rng, type, p) for _ in 1:dataset_size]
-
     # True weights
     B = rand(rng, Bernoulli(0.5), E, p)
     ξ = if ν == 0.0
-        [ones(type, E) for _ in 1:dataset_size]
+        ones(type, E)
     else
-        [rand(rng, Uniform{type}(1 - ν, 1 + ν), E) for _ in 1:dataset_size]
+        rand(rng, Uniform{type}(1 - ν, 1 + ν), E)
     end
-    costs = [
-        -(1 .+ (3 .+ B * zᵢ ./ type(sqrt(p))) .^ deg) .* ξᵢ for (ξᵢ, zᵢ) in zip(ξ, features)
-    ]
-
-    shortest_path_maximizer = Utils.generate_maximizer(bench)
-
-    # Label solutions
-    solutions = shortest_path_maximizer.(costs)
-    return [
-        DataSample(; x, θ_true, y_true) for
-        (x, θ_true, y_true) in zip(features, costs, solutions)
-    ]
+    costs = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
+
+    maximizer = Utils.generate_maximizer(bench)
+    solution = maximizer(costs)
+    return DataSample(; x=features, θ_true=costs, y_true=solution)
 end

 """

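Each call now produces one feature vector of length `p`, one cost vector with an entry per edge, and the corresponding shortest-path label. Usage sketch under the same assumptions as above (default constructor, exported `generate_sample`):

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    bench = FixedSizeShortestPathBenchmark()   # assumed default graph size
    rng = MersenneTwister(0)
    sample = generate_sample(bench, rng)
    length(sample.θ_true)   # one (negative) cost per edge of bench.graph
    sample.y_true           # shortest-path label returned by the benchmark maximizer
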
src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 16 additions & 1 deletion
@@ -7,7 +7,7 @@ using Flux: Chain, Dense
 using Ipopt: Ipopt
 using JuMP: @variable, @objective, @constraint, optimize!, value, Model, set_silent
 using LinearAlgebra: I
-using Random: MersenneTwister
+using Random: AbstractRNG, MersenneTwister

 """
 $TYPEDEF
@@ -82,6 +82,21 @@ function Utils.generate_maximizer(bench::PortfolioOptimizationBenchmark)
     return portfolio_maximizer
 end

+function Utils.generate_sample(
+    bench::PortfolioOptimizationBenchmark, rng::AbstractRNG; type::Type=Float32
+)
+    (; d, p, deg, ν, L, f) = bench
+    features = randn(rng, type, p, d)
+    B = rand(rng, Bernoulli(0.5), d, p)
+    c̄ = (0.05 / type(sqrt(p)) .* B * features .+ 0.1^(1 / deg)) .^ deg
+    costs = c̄ .+ L * f .+ 0.01 * ν * randn(rng, type, d)
+
+    maximizer = Utils.generate_maximizer(bench)
+    solution = maximizer(costs)
+
+    return DataSample(; x=features, θ_true=c̄, y_true=solution)
+end
+
 """
 $TYPEDSIGNATURES

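Note that `θ_true` stores the expected costs `c̄`, while the label `y_true` is the portfolio chosen by the maximizer for the noisy realized `costs`. Usage sketch (default constructor and exported `generate_sample` assumed):

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    bench = PortfolioOptimizationBenchmark()   # assumed default constructor
    rng = MersenneTwister(0)
    sample = generate_sample(bench, rng)
    sample.θ_true   # expected costs c̄
    sample.y_true   # portfolio that is optimal for the noisy realized costs
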
src/Ranking/Ranking.jl

Lines changed: 7 additions & 13 deletions
@@ -61,22 +61,16 @@ end
 """
 $TYPEDSIGNATURES

-Generate a dataset of labeled instances for the ranking problem.
+Generate a labeled sample for the ranking problem.
 """
-function Utils.generate_dataset(
-    bench::RankingBenchmark, dataset_size::Int=10; seed::Int=0, noise_std=0.0
+function Utils.generate_sample(
+    bench::RankingBenchmark, rng::AbstractRNG; noise_std::Float32=0.0f0
 )
     (; instance_dim, nb_features, encoder) = bench
-    rng = MersenneTwister(seed)
-    features = [randn(rng, Float32, nb_features, instance_dim) for _ in 1:dataset_size]
-    costs = encoder.(features)
-    noisy_solutions = [
-        ranking(θ .+ noise_std * randn(rng, Float32, instance_dim)) for θ in costs
-    ]
-    return [
-        DataSample(; x, θ_true, y_true) for
-        (x, θ_true, y_true) in zip(features, costs, noisy_solutions)
-    ]
+    features = randn(rng, Float32, nb_features, instance_dim)
+    costs = encoder(features)
+    noisy_solution = ranking(costs .+ noise_std * randn(rng, Float32, instance_dim))
+    return DataSample(; x=features, θ_true=costs, y_true=noisy_solution)
 end

 """

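The ranking sample mirrors the argmax one, with `ranking` in place of `one_hot_argmax`. Sketch of drawing several samples from one RNG (same assumptions as before):

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    bench = RankingBenchmark()   # assumed default constructor
    rng = MersenneTwister(0)
    samples = [generate_sample(bench, rng; noise_std=0.05f0) for _ in 1:100]
    # each sample.y_true is the ranking of the noise-perturbed costs
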
src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 14 additions & 25 deletions
@@ -73,43 +73,32 @@ end
 """
 $TYPEDSIGNATURES

-Create a dataset of `dataset_size` instances for the given `StochasticVehicleSchedulingBenchmark`.
-If you want to not add label solutions in the dataset, set `compute_solutions=false`.
+Generate a sample for the given `StochasticVehicleSchedulingBenchmark`.
+If you do not want to add a label solution to the sample, set `compute_solutions=false`.
 By default, they will be computed using column generation.
 Note that computing solutions can be time-consuming, especially for large instances.
 You can also use instead `compact_mip` or `compact_linearized_mip` as the algorithm to compute solutions.
 If you want to provide a custom algorithm to compute solutions, you can pass it as the `algorithm` keyword argument.
 If `algorithm` takes keyword arguments, you can pass them as well directly in `kwargs...`.
-If `store_city=false`, the coordinates and unnecessary information about instances will not be stored in the dataset.
+If `store_city=false`, the coordinates and unnecessary information about instances will not be stored in the sample.
 """
-function Utils.generate_dataset(
+function Utils.generate_sample(
     benchmark::StochasticVehicleSchedulingBenchmark,
-    dataset_size::Int;
+    rng::AbstractRNG;
+    store_city=true,
     compute_solutions=true,
-    seed=nothing,
-    rng=MersenneTwister(0),
     algorithm=column_generation_algorithm,
-    store_city=true,
     kwargs...,
 )
     (; nb_tasks, nb_scenarios) = benchmark
-    Random.seed!(rng, seed)
-    instances = [
-        Instance(; nb_tasks, nb_scenarios, rng, store_city) for _ in 1:dataset_size
-    ]
-    features = get_features.(instances)
-    if compute_solutions
-        solutions = [algorithm(instance; kwargs...).value for instance in instances]
-        return [
-            DataSample(; x=feature, instance, y_true=solution) for
-            (instance, feature, solution) in zip(instances, features, solutions)
-        ]
+    instance = Instance(; nb_tasks, nb_scenarios, rng, store_city)
+    x = get_features(instance)
+    y_true = if compute_solutions
+        algorithm(instance; kwargs...).value  # TODO: modify algorithms to directly return the solution
+    else
+        nothing
     end
-    # else
-    return [
-        DataSample(; x=feature, instance) for
-        (instance, feature) in zip(instances, features)
-    ]
+    return DataSample(; x, instance, y_true)
 end

 """
@@ -126,7 +115,7 @@ end
 $TYPEDSIGNATURES
 """
 function Utils.generate_maximizer(
-    bench::StochasticVehicleSchedulingBenchmark; model_builder=highs_model
+    ::StochasticVehicleSchedulingBenchmark; model_builder=highs_model
 )
     return StochasticVechicleSchedulingMaximizer(model_builder)
 end

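Usage sketch for the keyword arguments documented above; the constructor keywords `nb_tasks` and `nb_scenarios` are inferred from the destructuring in the diff and should be treated as assumptions:

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    bench = StochasticVehicleSchedulingBenchmark(; nb_tasks=25, nb_scenarios=10)
    rng = MersenneTwister(0)

    # Cheap, unlabeled sample: skip the expensive labeling step entirely.
    unlabeled = generate_sample(bench, rng; compute_solutions=false, store_city=false)

    # Labeled sample using one of the alternative algorithms named in the docstring.
    labeled = generate_sample(bench, rng; algorithm=compact_linearized_mip)
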
src/SubsetSelection/SubsetSelection.jl

Lines changed: 17 additions & 25 deletions
@@ -17,21 +17,28 @@ without knowing their values, but only observing some features.
 # Fields
 $TYPEDFIELDS
 """
-struct SubsetSelectionBenchmark <: AbstractBenchmark
+struct SubsetSelectionBenchmark{M} <: AbstractBenchmark
     "total number of items"
     n::Int
     "number of items to select"
     k::Int
+    "hidden unknown mapping from features to costs"
+    mapping::M
 end

 function Base.show(io::IO, bench::SubsetSelectionBenchmark)
     (; n, k) = bench
     return print(io, "SubsetSelectionBenchmark(n=$n, k=$k)")
 end

-function SubsetSelectionBenchmark(; n::Int=25, k::Int=5)
+function SubsetSelectionBenchmark(; n::Int=25, k::Int=5, identity_mapping::Bool=true)
     @assert n >= k "number of items n must be greater than k"
-    return SubsetSelectionBenchmark(n, k)
+    mapping = if identity_mapping
+        copy
+    else
+        Dense(n => n; bias=false)
+    end
+    return SubsetSelectionBenchmark(n, k, mapping)
 end

 function top_k(v::AbstractVector, k::Int)
@@ -54,29 +61,14 @@ end
 """
 $TYPEDSIGNATURES

-Generate a dataset of labeled instances for the subset selection problem.
-The mapping between features and cost is identity.
+Generate a labeled instance for the subset selection problem.
 """
-function Utils.generate_dataset(
-    bench::SubsetSelectionBenchmark,
-    dataset_size::Int=10;
-    seed::Int=0,
-    identity_mapping=true,
-)
-    (; n, k) = bench
-    rng = MersenneTwister(seed)
-    features = [randn(rng, Float32, n) for _ in 1:dataset_size]
-    costs = if identity_mapping
-        copy(features) # we assume that the cost is the same as the feature
-    else
-        mapping = Dense(n => n; bias=false)
-        mapping.(features)
-    end
-    solutions = top_k.(costs, k)
-    return [
-        DataSample(; x, θ_true, y_true) for
-        (x, θ_true, y_true) in zip(features, costs, solutions)
-    ]
+function Utils.generate_sample(bench::SubsetSelectionBenchmark, rng::AbstractRNG)
+    (; n, k, mapping) = bench
+    features = randn(rng, Float32, n)
+    costs = mapping(features)
+    solution = top_k(costs, k)
+    return DataSample(; x=features, θ_true=costs, y_true=solution)
 end

 """

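With the new `mapping` field, the default benchmark stores `copy` (so costs equal features), while `identity_mapping=false` stores a hidden random `Dense(n => n; bias=false)` layer. A quick illustrative check (field access on `DataSample` and an exported `generate_sample` are assumptions):

    using DecisionFocusedLearningBenchmarks
    using Random: MersenneTwister

    rng = MersenneTwister(0)
    bench_id  = SubsetSelectionBenchmark(; n=25, k=5)                           # identity mapping
    bench_lin = SubsetSelectionBenchmark(; n=25, k=5, identity_mapping=false)   # hidden Dense layer

    sample = generate_sample(bench_id, rng)
    sample.θ_true == sample.x   # true: with the identity mapping, costs coincide with features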