Skip to content

Commit 0168942

Browse files
committed
Adjust tests, and fix errors
1 parent 3828e5f commit 0168942

File tree

18 files changed

+100
-87
lines changed

18 files changed

+100
-87
lines changed

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ function anticipative_solver(
9292
job_indices = 2:nb_nodes
9393
epoch_indices = T
9494

95-
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
95+
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true)
9696

9797
@objective(
9898
model,
9999
Max,
100100
sum(
101-
-duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes,
102-
t in epoch_indices
101+
-duration[i, j] * y[i, j, t] for
102+
i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices
103103
)
104104
)
105105

@@ -171,14 +171,12 @@ function anticipative_solver(
171171
routes = epoch_routes[i]
172172
epoch_customers = epoch_indices[i]
173173

174-
y_true =
175-
VSPSolution(
176-
Vector{Int}[
177-
map(idx -> findfirst(==(idx), epoch_customers), route) for
178-
route in routes
179-
];
180-
max_index=length(epoch_customers),
181-
).edge_matrix
174+
y_true = VSPSolution(
175+
Vector{Int}[
176+
map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes
177+
];
178+
max_index=length(epoch_customers),
179+
).edge_matrix
182180

183181
location_indices = customer_index[epoch_customers]
184182
new_coordinates = env.instance.static_instance.coordinate[location_indices]
@@ -202,7 +200,8 @@ function anticipative_solver(
202200
is_must_dispatch[2:end] .= true
203201
else
204202
is_must_dispatch[2:end] .=
205-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
203+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
204+
new_start_time[2:end]
206205
end
207206
is_postponable[2:end] .= .!is_must_dispatch[2:end]
208207
# TODO: avoid code duplication with add_new_customers!
@@ -223,7 +222,7 @@ function anticipative_solver(
223222
compute_features(state, env.instance)
224223
end
225224

226-
return DataSample(; info=(; state, reward), y, x)
225+
return DataSample(; info=(; state, reward), y=y_true, x)
227226
end
228227

229228
return obj, dataset

src/DynamicVehicleScheduling/plot.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ The returned dictionary contains:
208208
This lets plotting code build figures without depending on plotting internals.
209209
"""
210210
function build_plot_data(data_samples::Vector{<:DataSample})
211-
state_data = [build_state_data(sample.instance.state) for sample in data_samples]
212-
rewards = [sample.instance.reward for sample in data_samples]
213-
routess = [sample.y_true for sample in data_samples]
211+
state_data = [build_state_data(sample.info.state) for sample in data_samples]
212+
rewards = [sample.info.reward for sample in data_samples]
213+
routess = [sample.y for sample in data_samples]
214214
return [
215215
(; state..., reward, routes) for
216216
(state, reward, routes) in zip(state_data, rewards, routess)
@@ -273,8 +273,8 @@ function plot_epochs(
273273
# Create subplots
274274
plots = map(1:n_epochs) do i
275275
sample = data_samples[i]
276-
state = sample.instance.state
277-
reward = sample.instance.reward
276+
state = sample.info.state
277+
reward = sample.info.reward
278278

279279
common_kwargs = Dict(
280280
:xlims => xlims,
@@ -292,7 +292,7 @@ function plot_epochs(
292292
if plot_routes_flag
293293
fig = plot_routes(
294294
state,
295-
sample.y_true;
295+
sample.y;
296296
reward=reward,
297297
show_route_labels=false,
298298
common_kwargs...,
@@ -351,7 +351,7 @@ function animate_epochs(
351351
kwargs...,
352352
)
353353
pd = build_plot_data(data_samples)
354-
epoch_costs = [-sample.instance.reward for sample in data_samples]
354+
epoch_costs = [-sample.info.reward for sample in data_samples]
355355

356356
# Calculate global xlims and ylims from all states
357357
x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd)
@@ -393,12 +393,12 @@ function animate_epochs(
393393
anim = @animate for frame_idx in 1:total_frames
394394
epoch_idx, frame_type = frame_plan[frame_idx]
395395
sample = data_samples[epoch_idx]
396-
state = sample.instance.state
396+
state = sample.info.state
397397

398398
if frame_type == :routes
399399
fig = plot_routes(
400400
state,
401-
sample.y_true;
401+
sample.y;
402402
xlims=xlims,
403403
ylims=ylims,
404404
clims=clims,

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ function plot_solution(
208208
)
209209
(; tasks, district_width, width) = sample.info.city
210210
ticks = 0:district_width:width
211-
solution = Solution(sample.y_true, sample.info)
211+
solution = Solution(sample.y, sample.info)
212212
path_list = compute_path_list(solution)
213213
fig = plot(;
214214
xlabel="x",

src/Utils/policy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function evaluate_policy!(
4444
features, state = observe(env)
4545
state_copy = deepcopy(state) # To avoid mutation issues
4646
reward = step!(env, y)
47-
sample = DataSample(; x=features, y_true=y, instance=(; state=state_copy, reward))
47+
sample = DataSample(; x=features, y=y, info=(; state=state_copy, reward))
4848
if @isdefined labeled_dataset
4949
push!(labeled_dataset, sample)
5050
else

src/Warcraft/Warcraft.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ The keyword argument `θ_true` is used to set the color range of the weights plo
9191
function Utils.plot_data(
9292
::WarcraftBenchmark,
9393
sample::DataSample;
94-
θ_true=sample.θ_true,
94+
θ_true=sample.θ,
9595
θ_title="Weights",
9696
y_title="Path",
9797
kwargs...,
9898
)
9999
x = sample.x
100-
y = sample.y_true
101-
θ = sample.θ_true
100+
y = sample.y
101+
θ = sample.θ
102102
im = dropdims(x; dims=4)
103103
img = convert_image_for_plot(im)
104104
p1 = Plots.plot(

src/Warcraft/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ function create_dataset(decompressed_path::String, nb_samples::Int)
4040
]
4141
Y = [BitMatrix(terrain_labels[:, :, i]) for i in 1:N]
4242
WG = [-terrain_weights[:, :, i] for i in 1:N]
43-
return [DataSample(; x, y_true, θ_true) for (x, y_true, θ_true) in zip(X, Y, WG)]
43+
return [DataSample(; x, y=y_true, θ=θ_true) for (x, y_true, θ_true) in zip(X, Y, WG)]
4444
end
4545

4646
"""

test/argmax.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
@test gap >= 0
1919

2020
for (i, sample) in enumerate(dataset)
21-
(; x, θ_true, y_true) = sample
21+
x = sample.x
22+
θ_true = sample.θ
23+
y_true = sample.y
2224
@test size(x) == (nb_features, instance_dim)
2325
@test length(θ_true) == instance_dim
2426
@test length(y_true) == instance_dim
25-
@test isnothing(sample.instance)
27+
@test isnothing(sample.info)
2628
@test all(y_true .== maximizer(θ_true))
2729

2830
θ = model(x)

test/argmax_2d.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
@test figure isa Plots.Plot
2222

2323
for (i, sample) in enumerate(dataset)
24-
(; x, θ_true, y_true, instance) = sample
24+
x = sample.x
25+
θ_true = sample.θ
26+
y_true = sample.y
27+
instance = sample.info
2528
@test length(x) == nb_features
2629
@test length(θ_true) == 2
2730
@test length(y_true) == 2
28-
@test !isnothing(sample.instance)
31+
@test !isnothing(instance)
2932
@test instance isa Vector{Vector{Float64}}
3033
@test all(length(vertex) == 2 for vertex in instance)
3134
@test y_true in instance

test/dynamic_assortment.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
const DAP = DecisionFocusedLearningBenchmarks.DynamicAssortment
33
end
44

5-
@testitem "DynamicAssortment - Benchmark Construction" setup=[Imports, DAPSetup] begin
5+
@testitem "DynamicAssortment - Benchmark Construction" setup = [Imports, DAPSetup] begin
66
# Test default constructor
77
b = DynamicAssortmentBenchmark()
88
@test b.N == 20
@@ -13,7 +13,7 @@ end
1313
@test !is_exogenous(b)
1414

1515
# Test custom constructor
16-
b_custom = DynamicAssortmentBenchmark(N=10, d=3, K=2, max_steps=50, exogenous=true)
16+
b_custom = DynamicAssortmentBenchmark(; N=10, d=3, K=2, max_steps=50, exogenous=true)
1717
@test b_custom.N == 10
1818
@test b_custom.d == 3
1919
@test b_custom.K == 2
@@ -28,8 +28,8 @@ end
2828
@test DAP.max_steps(b) == 80
2929
end
3030

31-
@testitem "DynamicAssortment - Instance Generation" setup=[Imports, DAPSetup] begin
32-
b = DynamicAssortmentBenchmark(N=5, d=3, K=2)
31+
@testitem "DynamicAssortment - Instance Generation" setup = [Imports, DAPSetup] begin
32+
b = DynamicAssortmentBenchmark(; N=5, d=3, K=2)
3333
rng = MersenneTwister(42)
3434

3535
instance = DAP.Instance(b, rng)
@@ -53,8 +53,8 @@ end
5353
@test DAP.prices(instance) == instance.prices
5454
end
5555

56-
@testitem "DynamicAssortment - Environment Initialization" setup=[Imports, DAPSetup] begin
57-
b = DynamicAssortmentBenchmark(N=5, d=2, K=2, max_steps=10)
56+
@testitem "DynamicAssortment - Environment Initialization" setup = [Imports, DAPSetup] begin
57+
b = DynamicAssortmentBenchmark(; N=5, d=2, K=2, max_steps=10)
5858
instance = DAP.Instance(b, MersenneTwister(42))
5959

6060
env = DAP.Environment(instance; seed=123)
@@ -80,8 +80,8 @@ end
8080
@test DAP.prices(env) == instance.prices
8181
end
8282

83-
@testitem "DynamicAssortment - Environment Reset" setup=[Imports, DAPSetup] begin
84-
b = DynamicAssortmentBenchmark(N=3, d=1, K=2, max_steps=5)
83+
@testitem "DynamicAssortment - Environment Reset" setup = [Imports, DAPSetup] begin
84+
b = DynamicAssortmentBenchmark(; N=3, d=1, K=2, max_steps=5)
8585
instance = DAP.Instance(b, MersenneTwister(42))
8686
env = DAP.Environment(instance; seed=123)
8787

@@ -107,8 +107,8 @@ end
107107
@test env.features expected_features
108108
end
109109

110-
@testitem "DynamicAssortment - Hype Update Logic" setup=[Imports, DAPSetup] begin
111-
b = DynamicAssortmentBenchmark(N=5, d=1, K=2)
110+
@testitem "DynamicAssortment - Hype Update Logic" setup = [Imports, DAPSetup] begin
111+
b = DynamicAssortmentBenchmark(; N=5, d=1, K=2)
112112
instance = DAP.Instance(b, MersenneTwister(42))
113113
env = DAP.Environment(instance; seed=123)
114114

@@ -135,8 +135,8 @@ end
135135
@test all(hype .== 1.0) # Should not affect any item hype
136136
end
137137

138-
@testitem "DynamicAssortment - Choice Probabilities" setup=[Imports, DAPSetup] begin
139-
b = DynamicAssortmentBenchmark(N=3, d=1, K=2)
138+
@testitem "DynamicAssortment - Choice Probabilities" setup = [Imports, DAPSetup] begin
139+
b = DynamicAssortmentBenchmark(; N=3, d=1, K=2)
140140
instance = DAP.Instance(b, MersenneTwister(42))
141141
env = DAP.Environment(instance; seed=123)
142142

@@ -167,8 +167,8 @@ end
167167
@test probs[4] 1.0 # Only no-purchase available
168168
end
169169

170-
@testitem "DynamicAssortment - Expected Revenue" setup=[Imports, DAPSetup] begin
171-
b = DynamicAssortmentBenchmark(N=3, d=1, K=2)
170+
@testitem "DynamicAssortment - Expected Revenue" setup = [Imports, DAPSetup] begin
171+
b = DynamicAssortmentBenchmark(; N=3, d=1, K=2)
172172
instance = DAP.Instance(b, MersenneTwister(42))
173173
env = DAP.Environment(instance; seed=123)
174174

@@ -183,8 +183,8 @@ end
183183
@test revenue == 0.0 # Only no-purchase available with price 0
184184
end
185185

186-
@testitem "DynamicAssortment - Environment Step" setup=[Imports, DAPSetup] begin
187-
b = DynamicAssortmentBenchmark(N=3, d=1, K=2, max_steps=5)
186+
@testitem "DynamicAssortment - Environment Step" setup = [Imports, DAPSetup] begin
187+
b = DynamicAssortmentBenchmark(; N=3, d=1, K=2, max_steps=5)
188188
instance = DAP.Instance(b, MersenneTwister(42))
189189
env = DAP.Environment(instance; seed=123)
190190

@@ -219,9 +219,9 @@ end
219219
@test_throws AssertionError step!(env, assortment)
220220
end
221221

222-
@testitem "DynamicAssortment - Endogenous vs Exogenous" setup=[Imports, DAPSetup] begin
222+
@testitem "DynamicAssortment - Endogenous vs Exogenous" setup = [Imports, DAPSetup] begin
223223
# Test endogenous environment (features change with purchases)
224-
b_endo = DynamicAssortmentBenchmark(N=3, d=1, K=2, exogenous=false)
224+
b_endo = DynamicAssortmentBenchmark(; N=3, d=1, K=2, exogenous=false)
225225
instance_endo = DAP.Instance(b_endo, MersenneTwister(42))
226226
env_endo = DAP.Environment(instance_endo; seed=123)
227227

@@ -232,7 +232,7 @@ end
232232
@test any(env_endo.d_features .!= 0.0) # Delta features should be non-zero
233233

234234
# Test exogenous environment (features don't change with purchases)
235-
b_exo = DynamicAssortmentBenchmark(N=3, d=1, K=2, exogenous=true)
235+
b_exo = DynamicAssortmentBenchmark(; N=3, d=1, K=2, exogenous=true)
236236
instance_exo = DAP.Instance(b_exo, MersenneTwister(42))
237237
env_exo = DAP.Environment(instance_exo; seed=123)
238238

@@ -243,8 +243,8 @@ end
243243
@test all(env_exo.d_features .== 0.0) # Delta features should remain zero
244244
end
245245

246-
@testitem "DynamicAssortment - Observation" setup=[Imports, DAPSetup] begin
247-
b = DynamicAssortmentBenchmark(N=3, d=2, max_steps=10)
246+
@testitem "DynamicAssortment - Observation" setup = [Imports, DAPSetup] begin
247+
b = DynamicAssortmentBenchmark(; N=3, d=2, max_steps=10)
248248
instance = DAP.Instance(b, MersenneTwister(42))
249249
env = DAP.Environment(instance; seed=123)
250250

@@ -266,10 +266,10 @@ end
266266
@test obs1 != obs2 # Observations should differ after purchase
267267
end
268268

269-
@testitem "DynamicAssortment - Policies" setup=[Imports, DAPSetup] begin
269+
@testitem "DynamicAssortment - Policies" setup = [Imports, DAPSetup] begin
270270
using Statistics: mean
271271

272-
b = DynamicAssortmentBenchmark(N=5, d=2, K=3, max_steps=20)
272+
b = DynamicAssortmentBenchmark(; N=5, d=2, K=3, max_steps=20)
273273

274274
# Generate test data
275275
dataset = generate_dataset(b, 10; seed=0)
@@ -307,8 +307,8 @@ end
307307
@test sum(greedy_action) == DAP.assortment_size(env)
308308
end
309309

310-
@testitem "DynamicAssortment - Model and Maximizer Integration" setup=[Imports, DAPSetup] begin
311-
b = DynamicAssortmentBenchmark(N=4, d=3, K=2)
310+
@testitem "DynamicAssortment - Model and Maximizer Integration" setup = [Imports, DAPSetup] begin
311+
b = DynamicAssortmentBenchmark(; N=4, d=3, K=2)
312312

313313
# Test statistical model generation
314314
model = generate_statistical_model(b; seed=42)
@@ -317,7 +317,7 @@ end
317317

318318
# Test integration with sample data
319319
sample = generate_sample(b, MersenneTwister(42))
320-
@test hasfield(typeof(sample), :instance)
320+
@test hasfield(typeof(sample), :info)
321321

322322
dataset = generate_dataset(b, 3; seed=42)
323323
environments = generate_environments(b, dataset)

test/dynamic_vsp.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
@test mean(r_lazy) <= mean(r_greedy)
2727

2828
env = environments[1]
29-
instance = dataset[1].instance
29+
instance = dataset[1].info
3030
scenario = generate_scenario(b, instance)
3131
v, y = generate_anticipative_solution(b, env, scenario; nb_epochs=2, reset_env=true)
3232

@@ -49,6 +49,8 @@
4949

5050
anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true)
5151
reset!(env; reset_rng=true)
52-
cost = sum(step!(env, sample.y_true) for sample in solution)
52+
cost = sum(step!(env, sample.y) for sample in solution)
53+
cost2 = sum(sample.info.reward for sample in solution)
5354
@test isapprox(cost, anticipative_value; atol=1e-5)
55+
@test isapprox(cost, cost2; atol=1e-5)
5456
end

0 commit comments

Comments
 (0)