Skip to content

Commit edaf383

Browse files
committed
Fix anticipative policy
1 parent 976335d commit edaf383

File tree

3 files changed

+122
-82
lines changed

3 files changed

+122
-82
lines changed

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -48,43 +48,57 @@ function anticipative_solver(
4848
model_builder=highs_model,
4949
two_dimensional_features=env.instance.two_dimensional_features,
5050
reset_env=true,
51-
nb_epochs=typemax(Int),
51+
nb_epochs=nothing,
5252
seed=get_seed(env),
53+
verbose=false,
5354
)
5455
if reset_env
5556
reset!(env; reset_rng=true, seed)
5657
end
5758

59+
@assert !is_terminated(env)
60+
5861
start_epoch = current_epoch(env)
59-
end_epoch = min(last_epoch(env), start_epoch + nb_epochs - 1)
62+
end_epoch = if isnothing(nb_epochs)
63+
last_epoch(env)
64+
else
65+
min(last_epoch(env), start_epoch + nb_epochs - 1)
66+
end
6067
T = start_epoch:end_epoch
68+
TT = (start_epoch + 1):end_epoch # horizon without start epoch
69+
70+
starting_state = deepcopy(env.state)
6171

6272
request_epoch = [0]
63-
for t in T
73+
request_epoch = vcat(request_epoch, fill(start_epoch, customer_count(starting_state)))
74+
for t in TT
6475
request_epoch = vcat(request_epoch, fill(t, length(scenario.indices[t])))
6576
end
66-
customer_index = vcat(1, scenario.indices[T]...)
67-
service_time = vcat(0.0, scenario.service_time[T]...)
68-
start_time = vcat(0.0, scenario.start_time[T]...)
77+
78+
customer_index = vcat(starting_state.location_indices, scenario.indices[TT]...)
79+
service_time = vcat(
80+
starting_state.state_instance.service_time, scenario.service_time[TT]...
81+
)
82+
start_time = vcat(starting_state.state_instance.start_time, scenario.start_time[TT]...)
6983

7084
duration = env.instance.static_instance.duration[customer_index, customer_index]
7185
(; epoch_duration, Δ_dispatch) = env.instance
7286

7387
model = model_builder()
74-
set_silent(model)
88+
verbose || set_silent(model)
7589

7690
nb_nodes = length(customer_index)
7791
job_indices = 2:nb_nodes
7892
epoch_indices = T
7993

80-
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true)
94+
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
8195

8296
@objective(
8397
model,
8498
Max,
8599
sum(
86-
-duration[i, j] * y[i, j, t] for
87-
i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices
100+
-duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes,
101+
t in epoch_indices
88102
)
89103
)
90104

@@ -136,38 +150,55 @@ function anticipative_solver(
136150
value.(y), env, customer_index, epoch_indices
137151
)
138152

139-
epoch_indices = Vector{Int}[]
140-
N = 1
141-
indices = [1]
142153
index = 1
143-
for epoch in 1:last_epoch(env)
144-
M = length(scenario.indices[epoch])
145-
indices = vcat(indices, (N + 1):(N + M))
146-
push!(epoch_indices, copy(indices))
154+
indices = collect(1:(customer_count(starting_state) + 1)) # current known indices in global indexing
155+
epoch_indices = [indices] # store global indices present at each epoch
156+
N = length(indices) # current last index known in global indexing
157+
for epoch in TT # 1:last_epoch(env)
158+
# remove dispatched customers from indices
159+
dispatched = vcat(epoch_routes[index]...)
160+
indices = setdiff(indices, dispatched)
161+
162+
M = length(scenario.indices[epoch]) # number of new customers in epoch
163+
indices = vcat(indices, (N + 1):(N + M)) # add global indices of customers in epoch
164+
push!(epoch_indices, copy(indices)) # store global indices present at each epoch
147165
N = N + M
148-
if epoch in T
149-
dispatched = vcat(epoch_routes[index]...)
150-
index += 1
151-
indices = setdiff(indices, dispatched)
152-
end
166+
index += 1
153167
end
154-
155-
indices = vcat(1, scenario.indices...)
156-
start_time = vcat(0.0, scenario.start_time...)
157-
service_time = vcat(0.0, scenario.service_time...)
168+
# epoch_indices = Vector{Int}[] # store global indices present at each epoch
169+
# N = 1 # current last index known in global indexing (= depot)
170+
# index = 1
171+
# indices = [1]
172+
# for epoch in 1:last_epoch(env)
173+
# M = length(scenario.indices[epoch]) # number of new customers in epoch
174+
# indices = vcat(indices, (N + 1):(N + M)) # add global indices of customers in epoch
175+
# push!(epoch_indices, copy(indices))
176+
# N = N + M
177+
# if epoch in T #
178+
# dispatched = vcat(epoch_routes[index]...)
179+
# index += 1
180+
# indices = setdiff(indices, dispatched)
181+
# end
182+
# end
183+
184+
# indices = vcat(1, scenario.indices...)
185+
# start_time = vcat(0.0, scenario.start_time...)
186+
# service_time = vcat(0.0, scenario.service_time...)
158187

159188
dataset = map(enumerate(T)) do (i, epoch)
160189
routes = epoch_routes[i]
161-
epoch_customers = epoch_indices[epoch]
162-
163-
y_true = VSPSolution(
164-
Vector{Int}[
165-
map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes
166-
];
167-
max_index=length(epoch_customers),
168-
).edge_matrix
169-
170-
location_indices = indices[epoch_customers]
190+
epoch_customers = epoch_indices[i]
191+
192+
y_true =
193+
VSPSolution(
194+
Vector{Int}[
195+
map(idx -> findfirst(==(idx), epoch_customers), route) for
196+
route in routes
197+
];
198+
max_index=length(epoch_customers),
199+
).edge_matrix
200+
201+
location_indices = customer_index[epoch_customers]
171202
new_coordinates = env.instance.static_instance.coordinate[location_indices]
172203
new_start_time = start_time[epoch_customers]
173204
new_service_time = service_time[epoch_customers]
@@ -189,8 +220,7 @@ function anticipative_solver(
189220
is_must_dispatch[2:end] .= true
190221
else
191222
is_must_dispatch[2:end] .=
192-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
193-
new_start_time[2:end]
223+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
194224
end
195225
is_postponable[2:end] .= .!is_must_dispatch[2:end]
196226

src/DynamicVehicleScheduling/plot.jl

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -192,49 +192,7 @@ Plot a given DVSPState with routes overlaid. This version accepts routes as a Bi
192192
where entry (i,j) = true indicates an edge from location i to location j.
193193
"""
194194
function plot_routes(state::DVSPState, routes::BitMatrix; kwargs...)
195-
# Convert BitMatrix to vector of route vectors
196-
n_locations = size(routes, 1)
197-
route_vectors = Vector{Int}[]
198-
199-
# Find all outgoing edges from depot (location 1)
200-
depot_destinations = findall(routes[1, :])
201-
202-
# For each destination from depot, reconstruct the route
203-
for dest in depot_destinations
204-
if dest != 1 # Skip self-loops at depot
205-
route = Int[]
206-
current = dest
207-
push!(route, current)
208-
209-
# Follow the route until we return to depot
210-
while true
211-
# Find next location (should be unique for valid routes)
212-
next_locations = findall(routes[current, :])
213-
214-
# Filter out the depot for intermediate steps
215-
non_depot_next = filter(x -> x != 1, next_locations)
216-
217-
if isempty(non_depot_next)
218-
# Must return to depot, route is complete
219-
break
220-
elseif length(non_depot_next) == 1
221-
# Continue to next location
222-
current = non_depot_next[1]
223-
push!(route, current)
224-
else
225-
# Multiple outgoing edges - this shouldn't happen in valid routes
226-
# but we'll take the first one
227-
current = non_depot_next[1]
228-
push!(route, current)
229-
end
230-
end
231-
232-
if !isempty(route)
233-
push!(route_vectors, route)
234-
end
235-
end
236-
end
237-
195+
route_vectors = decode_bitmatrix_to_routes(routes)
238196
return plot_routes(state, route_vectors; kwargs...)
239197
end
240198

src/DynamicVehicleScheduling/state.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ function is_feasible(state::DVSPState, routes::Vector{Vector{Int}}; verbose::Boo
149149
if all(is_dispatched[is_must_dispatch])
150150
return true
151151
else
152-
verbose && @warn "Not all must-dispatch requests are dispatched"
152+
verbose &&
153+
@warn "Not all must-dispatch requests are dispatched $(is_dispatched[is_must_dispatch])"
153154
return false
154155
end
155156
end
@@ -180,6 +181,57 @@ function apply_routes!(
180181
return c
181182
end
182183

184+
function decode_bitmatrix_to_routes(routes::BitMatrix)
185+
# Convert BitMatrix to vector of route vectors
186+
n_locations = size(routes, 1)
187+
route_vectors = Vector{Int}[]
188+
189+
# Find all outgoing edges from depot (location 1)
190+
depot_destinations = findall(routes[1, :])
191+
192+
# For each destination from depot, reconstruct the route
193+
for dest in depot_destinations
194+
if dest != 1 # Skip self-loops at depot
195+
route = Int[]
196+
current = dest
197+
push!(route, current)
198+
199+
# Follow the route until we return to depot
200+
while true
201+
# Find next location (should be unique for valid routes)
202+
next_locations = findall(routes[current, :])
203+
204+
# Filter out the depot for intermediate steps
205+
non_depot_next = filter(x -> x != 1, next_locations)
206+
207+
if isempty(non_depot_next)
208+
# Must return to depot, route is complete
209+
break
210+
elseif length(non_depot_next) == 1
211+
# Continue to next location
212+
current = non_depot_next[1]
213+
push!(route, current)
214+
else
215+
# Multiple outgoing edges - this shouldn't happen in valid routes
216+
# but we'll take the first one
217+
current = non_depot_next[1]
218+
push!(route, current)
219+
end
220+
end
221+
222+
if !isempty(route)
223+
push!(route_vectors, route)
224+
end
225+
end
226+
end
227+
return route_vectors
228+
end
229+
230+
function apply_routes!(state::DVSPState, routes::BitMatrix; check_feasibility::Bool=true)
231+
route_vectors = decode_bitmatrix_to_routes(routes)
232+
return apply_routes!(state, route_vectors; check_feasibility)
233+
end
234+
183235
function cost(state::DVSPState, routes::Vector{Vector{Int}})
184236
return cost(routes, duration(state.state_instance))
185237
end

0 commit comments

Comments
 (0)