Skip to content

Commit 39a9795

Browse files
authored
Merge pull request #39 from JuliaDecisionFocusedLearning/fix-dvsp
Fix DVSP anticipative policy
2 parents 976335d + 3fe0027 commit 39a9795

File tree

7 files changed

+103
-73
lines changed

7 files changed

+103
-73
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ tutorial_files = readdir(tutorial_dir)
1313
md_tutorial_files = [split(file, ".")[1] * ".md" for file in tutorial_files]
1414
benchmark_files = [joinpath("benchmarks", e) for e in readdir(benchmarks_dir)]
1515

16-
include_tutorial = false
16+
include_tutorial = true
1717

1818
if include_tutorial
1919
for file in tutorial_files

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ function Utils.generate_policies(b::DynamicVehicleSchedulingBenchmark)
111111
return (lazy, greedy)
112112
end
113113

114-
function Utils.generate_statistical_model(b::DynamicVehicleSchedulingBenchmark)
114+
function Utils.generate_statistical_model(
115+
b::DynamicVehicleSchedulingBenchmark; seed=nothing
116+
)
117+
Random.seed!(seed)
115118
return Chain(Dense((b.two_dimensional_features ? 2 : 14) => 1), vec)
116119
end
117120

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,44 @@ function anticipative_solver(
4848
model_builder=highs_model,
4949
two_dimensional_features=env.instance.two_dimensional_features,
5050
reset_env=true,
51-
nb_epochs=typemax(Int),
51+
nb_epochs=nothing,
5252
seed=get_seed(env),
53+
verbose=false,
5354
)
5455
if reset_env
5556
reset!(env; reset_rng=true, seed)
5657
end
5758

59+
@assert !is_terminated(env)
60+
5861
start_epoch = current_epoch(env)
59-
end_epoch = min(last_epoch(env), start_epoch + nb_epochs - 1)
62+
end_epoch = if isnothing(nb_epochs)
63+
last_epoch(env)
64+
else
65+
min(last_epoch(env), start_epoch + nb_epochs - 1)
66+
end
6067
T = start_epoch:end_epoch
68+
TT = (start_epoch + 1):end_epoch # horizon without start epoch
69+
70+
starting_state = deepcopy(env.state)
6171

6272
request_epoch = [0]
63-
for t in T
73+
request_epoch = vcat(request_epoch, fill(start_epoch, customer_count(starting_state)))
74+
for t in TT
6475
request_epoch = vcat(request_epoch, fill(t, length(scenario.indices[t])))
6576
end
66-
customer_index = vcat(1, scenario.indices[T]...)
67-
service_time = vcat(0.0, scenario.service_time[T]...)
68-
start_time = vcat(0.0, scenario.start_time[T]...)
77+
78+
customer_index = vcat(starting_state.location_indices, scenario.indices[TT]...)
79+
service_time = vcat(
80+
starting_state.state_instance.service_time, scenario.service_time[TT]...
81+
)
82+
start_time = vcat(starting_state.state_instance.start_time, scenario.start_time[TT]...)
6983

7084
duration = env.instance.static_instance.duration[customer_index, customer_index]
7185
(; epoch_duration, Δ_dispatch) = env.instance
7286

7387
model = model_builder()
74-
set_silent(model)
88+
verbose || set_silent(model)
7589

7690
nb_nodes = length(customer_index)
7791
job_indices = 2:nb_nodes
@@ -136,29 +150,25 @@ function anticipative_solver(
136150
value.(y), env, customer_index, epoch_indices
137151
)
138152

139-
epoch_indices = Vector{Int}[]
140-
N = 1
141-
indices = [1]
142153
index = 1
143-
for epoch in 1:last_epoch(env)
144-
M = length(scenario.indices[epoch])
145-
indices = vcat(indices, (N + 1):(N + M))
146-
push!(epoch_indices, copy(indices))
154+
indices = collect(1:(customer_count(starting_state) + 1)) # current known indices in global indexing
155+
epoch_indices = [indices] # store global indices present at each epoch
156+
N = length(indices) # current last index known in global indexing
157+
for epoch in TT # 1:last_epoch(env)
158+
# remove dispatched customers from indices
159+
dispatched = vcat(epoch_routes[index]...)
160+
indices = setdiff(indices, dispatched)
161+
162+
M = length(scenario.indices[epoch]) # number of new customers in epoch
163+
indices = vcat(indices, (N + 1):(N + M)) # add global indices of customers in epoch
164+
push!(epoch_indices, copy(indices)) # store global indices present at each epoch
147165
N = N + M
148-
if epoch in T
149-
dispatched = vcat(epoch_routes[index]...)
150-
index += 1
151-
indices = setdiff(indices, dispatched)
152-
end
166+
index += 1
153167
end
154168

155-
indices = vcat(1, scenario.indices...)
156-
start_time = vcat(0.0, scenario.start_time...)
157-
service_time = vcat(0.0, scenario.service_time...)
158-
159169
dataset = map(enumerate(T)) do (i, epoch)
160170
routes = epoch_routes[i]
161-
epoch_customers = epoch_indices[epoch]
171+
epoch_customers = epoch_indices[i]
162172

163173
y_true = VSPSolution(
164174
Vector{Int}[
@@ -167,7 +177,7 @@ function anticipative_solver(
167177
max_index=length(epoch_customers),
168178
).edge_matrix
169179

170-
location_indices = indices[epoch_customers]
180+
location_indices = customer_index[epoch_customers]
171181
new_coordinates = env.instance.static_instance.coordinate[location_indices]
172182
new_start_time = start_time[epoch_customers]
173183
new_service_time = service_time[epoch_customers]
@@ -184,7 +194,7 @@ function anticipative_solver(
184194
epoch_duration = env.instance.epoch_duration
185195
Δ_dispatch = env.instance.Δ_dispatch
186196
planning_start_time = (epoch - 1) * epoch_duration + Δ_dispatch
187-
if epoch == last_epoch
197+
if epoch == end_epoch
188198
# If we are in the last epoch, all requests must be dispatched
189199
is_must_dispatch[2:end] .= true
190200
else
@@ -193,6 +203,7 @@ function anticipative_solver(
193203
new_start_time[2:end]
194204
end
195205
is_postponable[2:end] .= .!is_must_dispatch[2:end]
206+
# TODO: avoid code duplication with add_new_customers!
196207

197208
state = DVSPState(;
198209
state_instance=static_instance,

src/DynamicVehicleScheduling/plot.jl

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -192,49 +192,7 @@ Plot a given DVSPState with routes overlaid. This version accepts routes as a Bi
192192
where entry (i,j) = true indicates an edge from location i to location j.
193193
"""
194194
function plot_routes(state::DVSPState, routes::BitMatrix; kwargs...)
195-
# Convert BitMatrix to vector of route vectors
196-
n_locations = size(routes, 1)
197-
route_vectors = Vector{Int}[]
198-
199-
# Find all outgoing edges from depot (location 1)
200-
depot_destinations = findall(routes[1, :])
201-
202-
# For each destination from depot, reconstruct the route
203-
for dest in depot_destinations
204-
if dest != 1 # Skip self-loops at depot
205-
route = Int[]
206-
current = dest
207-
push!(route, current)
208-
209-
# Follow the route until we return to depot
210-
while true
211-
# Find next location (should be unique for valid routes)
212-
next_locations = findall(routes[current, :])
213-
214-
# Filter out the depot for intermediate steps
215-
non_depot_next = filter(x -> x != 1, next_locations)
216-
217-
if isempty(non_depot_next)
218-
# Must return to depot, route is complete
219-
break
220-
elseif length(non_depot_next) == 1
221-
# Continue to next location
222-
current = non_depot_next[1]
223-
push!(route, current)
224-
else
225-
# Multiple outgoing edges - this shouldn't happen in valid routes
226-
# but we'll take the first one
227-
current = non_depot_next[1]
228-
push!(route, current)
229-
end
230-
end
231-
232-
if !isempty(route)
233-
push!(route_vectors, route)
234-
end
235-
end
236-
end
237-
195+
route_vectors = decode_bitmatrix_to_routes(routes)
238196
return plot_routes(state, route_vectors; kwargs...)
239197
end
240198

src/DynamicVehicleScheduling/state.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ function is_feasible(state::DVSPState, routes::Vector{Vector{Int}}; verbose::Boo
149149
if all(is_dispatched[is_must_dispatch])
150150
return true
151151
else
152-
verbose && @warn "Not all must-dispatch requests are dispatched"
152+
verbose &&
153+
@warn "Not all must-dispatch requests are dispatched $(is_dispatched[is_must_dispatch])"
153154
return false
154155
end
155156
end
@@ -180,6 +181,58 @@ function apply_routes!(
180181
return c
181182
end
182183

184+
function decode_bitmatrix_to_routes(routes::BitMatrix)
185+
# Convert BitMatrix to vector of route vectors
186+
n_locations = size(routes, 1)
187+
route_vectors = Vector{Int}[]
188+
189+
# Find all outgoing edges from depot (location 1)
190+
depot_destinations = findall(routes[1, :])
191+
192+
# For each destination from depot, reconstruct the route
193+
for dest in depot_destinations
194+
if dest != 1 # Skip self-loops at depot
195+
route = Int[]
196+
current = dest
197+
push!(route, current)
198+
199+
# Follow the route until we return to depot
200+
while true
201+
# Find next location (should be unique for valid routes)
202+
next_locations = findall(routes[current, :])
203+
204+
# Filter out the depot for intermediate steps
205+
non_depot_next = filter(x -> x != 1, next_locations)
206+
207+
if isempty(non_depot_next)
208+
# Must return to depot, route is complete
209+
break
210+
elseif length(non_depot_next) == 1
211+
# Continue to next location
212+
current = non_depot_next[1]
213+
push!(route, current)
214+
else
215+
throw(
216+
ErrorException(
217+
"Invalid route: multiple outgoing edges from location $current"
218+
),
219+
)
220+
end
221+
end
222+
223+
if !isempty(route)
224+
push!(route_vectors, route)
225+
end
226+
end
227+
end
228+
return route_vectors
229+
end
230+
231+
function apply_routes!(state::DVSPState, routes::BitMatrix; check_feasibility::Bool=true)
232+
route_vectors = decode_bitmatrix_to_routes(routes)
233+
return apply_routes!(state, route_vectors; check_feasibility)
234+
end
235+
183236
function cost(state::DVSPState, routes::Vector{Vector{Int}})
184237
return cost(routes, duration(state.state_instance))
185238
end

src/Utils/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ Generate a vector of environments for the given dynamic benchmark and dataset.
232232
"""
233233
function generate_environments(
234234
bench::AbstractDynamicBenchmark,
235-
dataset::Vector{<:DataSample};
235+
dataset::AbstractArray{<:DataSample};
236236
seed=nothing,
237237
rng=MersenneTwister(seed),
238238
kwargs...,

test/dynamic_vsp.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,9 @@
4646
y2 = maximizer(θ2; instance=instance2)
4747
@test size(x, 1) == 2
4848
@test size(x2, 1) == 14
49+
50+
anticipative_value, solution = generate_anticipative_solution(b, env; reset_env=true)
51+
reset!(env; reset_rng=true)
52+
cost = sum(step!(env, sample.y_true) for sample in solution)
53+
@test isapprox(cost, anticipative_value; atol=1e-5)
4954
end

0 commit comments

Comments
 (0)