Skip to content

Commit ecb4054

Browse files
committed
replace InferOptDataset by Vector{DataSample}
1 parent e381b59 commit ecb4054

File tree

12 files changed

+34
-64
lines changed

12 files changed

+34
-64
lines changed

docs/src/tutorials/warcraft.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ model = generate_statistical_model(b)
4545
θ = model(x)
4646
# Note that the model is not trained yet, and its parameters are randomly initialized.
4747

48-
# Finally, the [`generate_maximizer`](@ref) method can be used generates a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:
48+
# Finally, the [`generate_maximizer`](@ref) method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:
4949
maximizer = generate_maximizer(b; dijkstra=true)
5050
# In the case o fthe Warcraft benchmark, the method has an additioonal keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
5151
y = maximizer(θ)

src/FixedSizeShortestPath/shortest_paths.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ function Utils.generate_dataset(
121121

122122
# Label solutions
123123
solutions = shortest_path_maximizer.(.-costs)
124-
return InferOptDataset(; features, costs, solutions)
124+
return [DataSample(; x=x, θ=θ, y=y) for (x, θ, y) in zip(features, costs, solutions)]
125125
end
126126

127127
"""

src/InferOptBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using .PortfolioOptimization
3131
using .SubsetSelection
3232

3333
# Interface
34-
export AbstractBenchmark, InferOptDataset
34+
export AbstractBenchmark, DataSample
3535
export generate_dataset
3636
export generate_statistical_model
3737
export generate_maximizer

src/PortfolioOptimization/portfolio_optimization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function Utils.generate_dataset(
9999
maximizer = Utils.generate_maximizer(bench)
100100
solutions = maximizer.(costs)
101101

102-
return InferOptDataset(; features, costs, solutions)
102+
return [DataSample(; x=x, θ=θ, y=y) for (x, θ, y) in zip(features, costs, solutions)]
103103
end
104104

105105
"""

src/ShortestPath/shortest_paths.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function Utils.generate_dataset(
116116

117117
# Label solutions
118118
solutions = shortest_path_maximizer.(.-costs)
119-
return InferOptDataset(; features, costs, solutions)
119+
return [DataSample(; x=x, θ=θ, y=y) for (x, θ, y) in zip(features, costs, solutions)]
120120
end
121121

122122
"""

src/SubsetSelection/subset_selection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function Utils.generate_dataset(
5858
features = [randn(rng, Float32, n) for _ in 1:dataset_size]
5959
costs = copy(features) # we assume that the cost is the same as the feature
6060
solutions = top_k.(features, k)
61-
return InferOptDataset(; features, solutions, costs)
61+
return [DataSample(; x=x, θ=θ, y=y) for (x, θ, y) in zip(features, costs, solutions)]
6262
end
6363

6464
"""

src/Utils/Utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ using Flux: softplus
55
using LinearAlgebra: dot
66
using SimpleWeightedGraphs: SimpleWeightedDiGraph
77

8-
include("dataset.jl")
8+
include("data_sample.jl")
99
include("interface.jl")
1010
include("grid_graph.jl")
1111
include("misc.jl")
1212

13-
export InferOptDataset
13+
export DataSample
1414

1515
export AbstractBenchmark
1616
export generate_dataset,

src/Utils/data_sample.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
$TYPEDEF
3+
4+
Data sample data structure.
5+
6+
# Fields
7+
$TYPEDFIELDS
8+
"""
9+
@kwdef struct DataSample{F,S,C,I}
10+
"features"
11+
x::F
12+
"costs"
13+
θ::C = nothing
14+
"solution"
15+
y::S = nothing
16+
"instance"
17+
instance::I = nothing
18+
end

src/Utils/dataset.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/Utils/interface.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,14 @@ $TYPEDSIGNATURES
6767
Default behaviour of `compute_gap` for a benchmark problem where `features`, `solutions` and `costs` are all defined.
6868
"""
6969
function compute_gap(
70-
bench::AbstractBenchmark,
71-
dataset::InferOptDataset{<:AbstractArray,<:AbstractArray,<:AbstractArray},
72-
statistical_model,
73-
maximizer,
70+
bench::AbstractBenchmark, dataset::Vector{<:DataSample}, statistical_model, maximizer
7471
)
7572
res = 0.0
76-
X = dataset.features
77-
costs = dataset.costs
78-
Y = dataset.solutions
7973

80-
for (x, θ̄, ȳ) in zip(X, costs, Y)
74+
for sample in dataset
75+
x = sample.x
76+
θ̄ = sample.θ
77+
= sample.y
8178
θ = statistical_model(x)
8279
y = maximizer(θ)
8380
target_obj = objective_value(bench, θ̄, ȳ)

0 commit comments

Comments
 (0)