Skip to content

Commit 3828e5f

Browse files
committed
Rename DataSample fields to x/theta/y/info
1 parent e59a763 commit 3828e5f

File tree

15 files changed

+80
-217
lines changed

15 files changed

+80
-217
lines changed

src/Argmax/Argmax.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ function Utils.generate_sample(
7676
)
7777
(; instance_dim, nb_features, encoder) = bench
7878
features = randn(rng, Float32, nb_features, instance_dim)
79-
costs = encoder(features)
80-
noisy_solution = one_hot_argmax(costs + noise_std * randn(rng, Float32, instance_dim))
81-
return DataSample(; x=features, θ_true=costs, y_true=noisy_solution)
79+
θ_true = encoder(features)
80+
noisy_y_true = one_hot_argmax(θ_true + noise_std * randn(rng, Float32, instance_dim))
81+
return DataSample(; x=features, θ=θ_true, y=noisy_y_true)
8282
end
8383

8484
"""

src/Argmax2D/Argmax2D.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
6262
θ_true ./= 2 * norm(θ_true)
6363
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
6464
y_true = maximizer(θ_true; instance)
65-
return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
65+
return DataSample(; x=x, θ=θ_true, y=y_true, info=instance)
6666
end
6767

6868
"""
@@ -88,11 +88,11 @@ function Utils.generate_statistical_model(
8888
return model
8989
end
9090

91-
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
91+
function Utils.plot_data(::Argmax2DBenchmark; info, θ, kwargs...)
9292
pl = init_plot()
93-
plot_polytope!(pl, instance)
93+
plot_polytope!(pl, info)
9494
plot_objective!(pl, θ)
95-
return plot_maximizer!(pl, θ, instance, maximizer)
95+
return plot_maximizer!(pl, θ, info, maximizer)
9696
end
9797

9898
"""
@@ -101,13 +101,9 @@ $TYPEDSIGNATURES
101101
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
102102
"""
103103
function Utils.plot_data(
104-
bench::Argmax2DBenchmark,
105-
sample::DataSample;
106-
instance=sample.instance,
107-
θ=sample.θ_true,
108-
kwargs...,
104+
bench::Argmax2DBenchmark, sample::DataSample; info=sample.info, θ=sample.θ, kwargs...
109105
)
110-
return Utils.plot_data(bench; instance, θ, kwargs...)
106+
return Utils.plot_data(bench; info, θ, kwargs...)
111107
end
112108

113109
export Argmax2DBenchmark

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Outputs a data sample containing an [`Instance`](@ref).
8383
function Utils.generate_sample(
8484
b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0)
8585
)
86-
return DataSample(; instance=Instance(b, rng))
86+
return DataSample(; info=Instance(b, rng))
8787
end
8888

8989
"""

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
6363
dataset_size = min(dataset_size, length(files))
6464
return [
6565
DataSample(;
66-
instance=Instance(
66+
info=Instance(
6767
read_vsp_instance(files[i]);
6868
max_requests_per_epoch,
6969
Δ_dispatch,

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ function anticipative_solver(
9292
job_indices = 2:nb_nodes
9393
epoch_indices = T
9494

95-
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true)
95+
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
9696

9797
@objective(
9898
model,
9999
Max,
100100
sum(
101-
-duration[i, j] * y[i, j, t] for
102-
i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices
101+
-duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes,
102+
t in epoch_indices
103103
)
104104
)
105105

@@ -171,12 +171,14 @@ function anticipative_solver(
171171
routes = epoch_routes[i]
172172
epoch_customers = epoch_indices[i]
173173

174-
y_true = VSPSolution(
175-
Vector{Int}[
176-
map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes
177-
];
178-
max_index=length(epoch_customers),
179-
).edge_matrix
174+
y_true =
175+
VSPSolution(
176+
Vector{Int}[
177+
map(idx -> findfirst(==(idx), epoch_customers), route) for
178+
route in routes
179+
];
180+
max_index=length(epoch_customers),
181+
).edge_matrix
180182

181183
location_indices = customer_index[epoch_customers]
182184
new_coordinates = env.instance.static_instance.coordinate[location_indices]
@@ -200,8 +202,7 @@ function anticipative_solver(
200202
is_must_dispatch[2:end] .= true
201203
else
202204
is_must_dispatch[2:end] .=
203-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
204-
new_start_time[2:end]
205+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
205206
end
206207
is_postponable[2:end] .= .!is_must_dispatch[2:end]
207208
# TODO: avoid code duplication with add_new_customers!
@@ -222,7 +223,7 @@ function anticipative_solver(
222223
compute_features(state, env.instance)
223224
end
224225

225-
return DataSample(; instance=(; state, reward), y_true, x)
226+
return DataSample(; info=(; state, reward), y, x)
226227
end
227228

228229
return obj, dataset

src/FixedSizeShortestPath/FixedSizeShortestPath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ function Utils.generate_sample(
121121
else
122122
rand(rng, Uniform{type}(1 - ν, 1 + ν), E)
123123
end
124-
costs = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
124+
θ_true = -(1 .+ (3 .+ B * features ./ type(sqrt(p))) .^ deg) .* ξ
125125

126126
maximizer = Utils.generate_maximizer(bench)
127-
solution = maximizer(costs)
128-
return DataSample(; x=features, θ_true=costs, y_true=solution)
127+
y_true = maximizer(θ_true)
128+
return DataSample(; x=features, θ=θ_true, y=y_true)
129129
end
130130

131131
"""

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ function Utils.generate_sample(
9494
features = randn(rng, type, p)
9595
B = rand(rng, Bernoulli(0.5), d, p)
9696
= (0.05 / type(sqrt(p)) .* B * features .+ 0.1^(1 / deg)) .^ deg
97-
costs =.+ L * f .+ 0.01 * ν * randn(rng, type, d)
97+
θ_true =.+ L * f .+ 0.01 * ν * randn(rng, type, d)
9898

9999
maximizer = Utils.generate_maximizer(bench)
100-
solution = maximizer(costs)
100+
y_true = maximizer(θ_true)
101101

102-
return DataSample(; x=features, θ_true=costs, y_true=solution)
102+
return DataSample(; x=features, θ=θ_true, y=y_true)
103103
end
104104

105105
"""

src/Ranking/Ranking.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ function Utils.generate_sample(
6868
)
6969
(; instance_dim, nb_features, encoder) = bench
7070
features = randn(rng, Float32, nb_features, instance_dim)
71-
costs = encoder(features)
72-
noisy_solution = ranking(costs .+ noise_std * randn(rng, Float32, instance_dim))
73-
return DataSample(; x=features, θ_true=costs, y_true=noisy_solution)
71+
θ_true = encoder(features)
72+
noisy_y_true = ranking(θ_true .+ noise_std * randn(rng, Float32, instance_dim))
73+
return DataSample(; x=features, θ=θ_true, y=noisy_y_true)
7474
end
7575

7676
"""

src/ShortestPath/ShortestPath.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/ShortestPath/shortest_paths.jl

Lines changed: 0 additions & 130 deletions
This file was deleted.

0 commit comments

Comments
 (0)