Skip to content

Commit 4060caf

Browse files
authored
Merge pull request #44 from JuliaDecisionFocusedLearning/rename-datasample-fields
Rename DataSample fields
2 parents 7f6788e + fb7977a commit 4060caf

33 files changed

+170
-294
lines changed

docs/src/benchmark_interfaces.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ All benchmarks work with [`DataSample`](@ref) objects that encapsulate the data
1111

1212
```julia
1313
@kwdef struct DataSample{I,F,S,C}
14-
x::F = nothing # Input features
15-
θ_true::C = nothing # True cost/utility parameters
16-
y_true::S = nothing # True optimal solution
17-
instance::I = nothing # Problem instance object/additional data
14+
x::F = nothing # Input features of the policy
15+
θ::C = nothing # Intermediate cost/utility parameters
16+
y::S = nothing # Output solution
17+
info::I = nothing # Additional data information (e.g., problem instance)
1818
end
1919
```
2020

21-
The `DataSample` provides flexibility - not all fields need to be populated depending on the benchmark type and use case.
21+
The `DataSample` provides flexibility, not all fields need to be populated depending on the benchmark type and use.
2222

2323
### Benchmark Type Hierarchy
2424

docs/src/tutorials/warcraft_tutorial.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ dataset = generate_dataset(b, 50);
2121
# Subdatasets can be created through regular slicing:
2222
train_dataset, test_dataset = dataset[1:45], dataset[46:50]
2323

24-
# And getting an individual sample will return a [`DataSample`](@ref) with four fields: `x`, `instance`, `θ`, and `y`:
24+
# And getting an individual sample will return a [`DataSample`](@ref) with four fields: `x`, `info`, `θ`, and `y`:
2525
sample = test_dataset[1]
2626
# `x` correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case:
2727
x = sample.x
28-
# `θ_true` correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:
29-
θ_true = sample.θ_true
30-
# `y_true` correspond to the optimal shortest path, encoded as a binary matrix:
31-
y_true = sample.y_true
32-
# `instance` is not used in this benchmark, therefore set to nothing:
33-
isnothing(sample.instance)
28+
# `θ` correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:
29+
θ_true = sample.θ
30+
# `y` correspond to the optimal shortest path, encoded as a binary matrix:
31+
y_true = sample.y
32+
# `info` is not used in this benchmark, therefore set to nothing:
33+
isnothing(sample.info)
3434

3535
# For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
3636
plot_data(b, sample)
@@ -50,7 +50,7 @@ maximizer = generate_maximizer(b; dijkstra=true)
5050
# In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
5151
y = maximizer(θ)
5252
# As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
53-
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
53+
plot_data(b, DataSample(; x, θ, y))
5454
# We can evaluate the current pipeline performance using the optimality gap metric:
5555
starting_gap = compute_gap(b, test_dataset, model, maximizer)
5656

@@ -70,7 +70,7 @@ opt_state = Flux.setup(Adam(1e-3), model)
7070
loss_history = Float64[]
7171
for epoch in 1:50
7272
val, grads = Flux.withgradient(model) do m
73-
sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset)
73+
sum(loss(m(x), y) for (; x, y) in train_dataset) / length(train_dataset)
7474
end
7575
Flux.update!(opt_state, model, grads[1])
7676
push!(loss_history, val)
@@ -85,7 +85,7 @@ final_gap = compute_gap(b, test_dataset, model, maximizer)
8585
#
8686
θ = model(x)
8787
y = maximizer(θ)
88-
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
88+
plot_data(b, DataSample(; x, θ, y))
8989

9090
using Test #src
9191
@test final_gap < starting_gap #src

src/Argmax/Argmax.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ function Utils.generate_sample(
7676
)
7777
(; instance_dim, nb_features, encoder) = bench
7878
features = randn(rng, Float32, nb_features, instance_dim)
79-
costs = encoder(features)
80-
noisy_solution = one_hot_argmax(costs + noise_std * randn(rng, Float32, instance_dim))
81-
return DataSample(; x=features, θ_true=costs, y_true=noisy_solution)
79+
θ_true = encoder(features)
80+
noisy_y_true = one_hot_argmax(θ_true + noise_std * randn(rng, Float32, instance_dim))
81+
return DataSample(; x=features, θ=θ_true, y=noisy_y_true)
8282
end
8383

8484
"""

src/Argmax2D/Argmax2D.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
6262
θ_true ./= 2 * norm(θ_true)
6363
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
6464
y_true = maximizer(θ_true; instance)
65-
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
65+
return DataSample(; x=x, θ=θ_true, y=y_true, info=instance)
6666
end
6767

6868
"""
@@ -88,11 +88,11 @@ function Utils.generate_statistical_model(
8888
return model
8989
end
9090

91-
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
91+
function Utils.plot_data(::Argmax2DBenchmark; info, θ, kwargs...)
9292
pl = init_plot()
93-
plot_polytope!(pl, instance)
93+
plot_polytope!(pl, info)
9494
plot_objective!(pl, θ)
95-
return plot_maximizer!(pl, θ, instance, maximizer)
95+
return plot_maximizer!(pl, θ, info, maximizer)
9696
end
9797

9898
"""
@@ -101,13 +101,9 @@ $TYPEDSIGNATURES
101101
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
102102
"""
103103
function Utils.plot_data(
104-
bench::Argmax2DBenchmark,
105-
sample::DataSample;
106-
instance=sample.instance,
107-
θ=sample.θ_true,
108-
kwargs...,
104+
bench::Argmax2DBenchmark, sample::DataSample; info=sample.info, θ=sample.θ, kwargs...
109105
)
110-
return Utils.plot_data(bench; instance, θ, kwargs...)
106+
return Utils.plot_data(bench; info, θ, kwargs...)
111107
end
112108

113109
export Argmax2DBenchmark

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Outputs a data sample containing an [`Instance`](@ref).
8383
function Utils.generate_sample(
8484
b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0)
8585
)
86-
return DataSample(; instance=Instance(b, rng))
86+
return DataSample(; info=Instance(b, rng))
8787
end
8888

8989
"""

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
6363
dataset_size = min(dataset_size, length(files))
6464
return [
6565
DataSample(;
66-
instance=Instance(
66+
info=Instance(
6767
read_vsp_instance(files[i]);
6868
max_requests_per_epoch,
6969
Δ_dispatch,

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function anticipative_solver(
222222
compute_features(state, env.instance)
223223
end
224224

225-
return DataSample(; instance=(; state, reward), y_true, x)
225+
return DataSample(; info=(; state, reward), y=y_true, x)
226226
end
227227

228228
return obj, dataset

src/DynamicVehicleScheduling/plot.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ The returned dictionary contains:
208208
This lets plotting code build figures without depending on plotting internals.
209209
"""
210210
function build_plot_data(data_samples::Vector{<:DataSample})
211-
state_data = [build_state_data(sample.instance.state) for sample in data_samples]
212-
rewards = [sample.instance.reward for sample in data_samples]
213-
routess = [sample.y_true for sample in data_samples]
211+
state_data = [build_state_data(sample.info.state) for sample in data_samples]
212+
rewards = [sample.info.reward for sample in data_samples]
213+
routess = [sample.y for sample in data_samples]
214214
return [
215215
(; state..., reward, routes) for
216216
(state, reward, routes) in zip(state_data, rewards, routess)
@@ -273,8 +273,8 @@ function plot_epochs(
273273
# Create subplots
274274
plots = map(1:n_epochs) do i
275275
sample = data_samples[i]
276-
state = sample.instance.state
277-
reward = sample.instance.reward
276+
state = sample.info.state
277+
reward = sample.info.reward
278278

279279
common_kwargs = Dict(
280280
:xlims => xlims,
@@ -292,7 +292,7 @@ function plot_epochs(
292292
if plot_routes_flag
293293
fig = plot_routes(
294294
state,
295-
sample.y_true;
295+
sample.y;
296296
reward=reward,
297297
show_route_labels=false,
298298
common_kwargs...,
@@ -351,7 +351,7 @@ function animate_epochs(
351351
kwargs...,
352352
)
353353
pd = build_plot_data(data_samples)
354-
epoch_costs = [-sample.instance.reward for sample in data_samples]
354+
epoch_costs = [-sample.info.reward for sample in data_samples]
355355

356356
# Calculate global xlims and ylims from all states
357357
x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd)
@@ -393,12 +393,12 @@ function animate_epochs(
393393
anim = @animate for frame_idx in 1:total_frames
394394
epoch_idx, frame_type = frame_plan[frame_idx]
395395
sample = data_samples[epoch_idx]
396-
state = sample.instance.state
396+
state = sample.info.state
397397

398398
if frame_type == :routes
399399
fig = plot_routes(
400400
state,
401-
sample.y_true;
401+
sample.y;
402402
xlims=xlims,
403403
ylims=ylims,
404404
clims=clims,

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ function Utils.generate_sample(
121121
else
122122
rand(rng, Uniform{type}(1 - ν, 1 + ν), E)
123123
end
124-
costs = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
124+
θ_true = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
125125

126126
maximizer = Utils.generate_maximizer(bench)
127-
solution = maximizer(costs)
128-
return DataSample(; x=features, θ_true=costs, y_true=solution)
127+
y_true = maximizer(θ_true)
128+
return DataSample(; x=features, θ=θ_true, y=y_true)
129129
end
130130

131131
"""

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ function Utils.generate_sample(
9494
features = randn(rng, type, p)
9595
B = rand(rng, Bernoulli(0.5), d, p)
9696
= (0.05 / type(sqrt(p)) .* B * features .+ 0.1^(1 / deg)) .^ deg
97-
costs =.+ L * f .+ 0.01 * ν * randn(rng, type, d)
97+
θ_true =.+ L * f .+ 0.01 * ν * randn(rng, type, d)
9898

9999
maximizer = Utils.generate_maximizer(bench)
100-
solution = maximizer(costs)
100+
y_true = maximizer(θ_true)
101101

102-
return DataSample(; x=features, θ_true=costs, y_true=solution)
102+
return DataSample(; x=features, θ=θ_true, y=y_true)
103103
end
104104

105105
"""

0 commit comments

Comments
 (0)