Skip to content

Commit ed01c45

Browse files
committed
Improve anticipative policy
1 parent 987fec8 commit ed01c45

File tree

8 files changed

+108
-50
lines changed

8 files changed

+108
-50
lines changed

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ $TYPEDFIELDS
6767
end
6868

6969
function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_size::Int=1)
70-
(; max_requests_per_epoch, Δ_dispatch, epoch_duration) = b
70+
(; max_requests_per_epoch, Δ_dispatch, epoch_duration, two_dimensional_features) = b
7171
files = readdir(datadep"dvrptw"; join=true)
7272
dataset_size = min(dataset_size, length(files))
7373
return [
@@ -77,6 +77,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
7777
max_requests_per_epoch,
7878
Δ_dispatch,
7979
epoch_duration,
80+
two_dimensional_features,
8081
),
8182
) for i in 1:dataset_size
8283
]

src/DynamicVehicleScheduling/algorithms/anticipative_solver.jl

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@ $TYPEDSIGNATURES
44
Retrieve anticipative routes solution from the given MIP solution `y`.
55
Outputs a set of routes per epoch.
66
"""
7-
function retrieve_routes_anticipative(y::AbstractArray, dvspenv::DVSPEnv, customer_index)
7+
function retrieve_routes_anticipative(
8+
y::AbstractArray, dvspenv::DVSPEnv, customer_index, epoch_indices
9+
)
810
nb_tasks = length(customer_index)
9-
first_epoch = 1
10-
(; last_epoch) = dvspenv.instance
11+
# first_epoch = 1
12+
# (; last_epoch) = dvspenv.instance
1113
job_indices = 2:(nb_tasks)
12-
epoch_indices = first_epoch:last_epoch
14+
# epoch_indices = first_epoch:last_epoch
1315

1416
routes = [Vector{Int}[] for _ in epoch_indices]
15-
for t in epoch_indices
17+
for (i, t) in enumerate(epoch_indices)
1618
start = [i for i in job_indices if y[1, i, t] 1]
1719
for task in start
1820
route = Int[]
@@ -28,7 +30,7 @@ function retrieve_routes_anticipative(y::AbstractArray, dvspenv::DVSPEnv, custom
2830
end
2931
current_task = next_task
3032
end
31-
push!(routes[t], route)
33+
push!(routes[i], route)
3234
end
3335
end
3436
return routes
@@ -44,28 +46,33 @@ function anticipative_solver(
4446
env::DVSPEnv,
4547
scenario=env.scenario;
4648
model_builder=highs_model,
47-
reset_env=false,
48-
two_dimensional_features=false,
49+
two_dimensional_features=env.instance.two_dimensional_features,
50+
reset_env=true,
51+
nb_epochs=typemax(Int),
4952
)
50-
reset_env && reset!(env)
53+
reset_env && reset!(env; reset_seed=true)
54+
55+
start_epoch = current_epoch(env)
56+
end_epoch = min(last_epoch(env), start_epoch + nb_epochs - 1)
57+
T = start_epoch:end_epoch
58+
5159
request_epoch = [0]
52-
for (epoch, indices) in enumerate(scenario.indices)
53-
request_epoch = vcat(request_epoch, fill(epoch, length(indices)))
60+
for t in T
61+
request_epoch = vcat(request_epoch, fill(t, length(scenario.indices[t])))
5462
end
55-
customer_index = vcat(1, scenario.indices...)
56-
service_time = vcat(0.0, scenario.service_time...)
57-
start_time = vcat(0.0, scenario.start_time...)
63+
customer_index = vcat(1, scenario.indices[T]...)
64+
service_time = vcat(0.0, scenario.service_time[T]...)
65+
start_time = vcat(0.0, scenario.start_time[T]...)
5866

5967
duration = env.instance.static_instance.duration[customer_index, customer_index]
60-
first_epoch = 1
61-
(; last_epoch, epoch_duration, Δ_dispatch) = env.instance
68+
(; epoch_duration, Δ_dispatch) = env.instance
6269

6370
model = model_builder()
6471
set_silent(model)
6572

6673
nb_nodes = length(customer_index)
6774
job_indices = 2:nb_nodes
68-
epoch_indices = first_epoch:last_epoch
75+
epoch_indices = T#first_epoch:last_epoch
6976

7077
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
7178

@@ -102,7 +109,7 @@ function anticipative_solver(
102109

103110
# a trip from i can be done only before limit date
104111
for i in job_indices, t in epoch_indices, j in 1:nb_nodes
105-
if (t - 1) * epoch_duration + duration[1, i] + Δ_dispatch > start_time[i] # ! this only works if first_epoch = 1
112+
if (t - 1) * epoch_duration + duration[1, i] + Δ_dispatch > start_time[i]
106113
@constraint(model, y[i, j, t] <= 0)
107114
end
108115
end
@@ -121,27 +128,32 @@ function anticipative_solver(
121128
optimize!(model)
122129

123130
obj = JuMP.objective_value(model)
124-
epoch_routes = retrieve_routes_anticipative(value.(y), env, customer_index)
131+
epoch_routes = retrieve_routes_anticipative(
132+
value.(y), env, customer_index, epoch_indices
133+
)
125134

126135
epoch_indices = Vector{Int}[]
127136
N = 1
128137
indices = [1]
129-
for epoch in 1:last_epoch
138+
index = 1
139+
for epoch in 1:last_epoch(env)
130140
M = length(scenario.indices[epoch])
131141
indices = vcat(indices, (N + 1):(N + M))
132142
push!(epoch_indices, copy(indices))
133143
N = N + M
134-
epoch_routes[epoch]
135-
dispatched = vcat(epoch_routes[epoch]...)
136-
indices = setdiff(indices, dispatched)
144+
if epoch in T
145+
dispatched = vcat(epoch_routes[index]...)
146+
index += 1
147+
indices = setdiff(indices, dispatched)
148+
end
137149
end
138150

139151
indices = vcat(1, scenario.indices...)
140152
start_time = vcat(0.0, scenario.start_time...)
141153
service_time = vcat(0.0, scenario.service_time...)
142154

143-
dataset = map(1:last_epoch) do epoch
144-
routes = epoch_routes[epoch]
155+
dataset = map(enumerate(T)) do (i, epoch)
156+
routes = epoch_routes[i]
145157
epoch_customers = epoch_indices[epoch]
146158

147159
y_true =
@@ -170,9 +182,13 @@ function anticipative_solver(
170182
epoch_duration = env.instance.epoch_duration
171183
Δ_dispatch = env.instance.Δ_dispatch
172184
planning_start_time = (epoch - 1) * epoch_duration + Δ_dispatch
173-
is_must_dispatch[2:end] .=
174-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
175-
new_start_time[2:end]
185+
if epoch == last_epoch
186+
# If we are in the last epoch, all requests must be dispatched
187+
is_must_dispatch[2:end] .= true
188+
else
189+
is_must_dispatch[2:end] .=
190+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
191+
end
176192
is_postponable[2:end] .= .!is_must_dispatch[2:end]
177193

178194
state = DVSPState(;
@@ -183,7 +199,6 @@ function anticipative_solver(
183199
current_epoch=epoch,
184200
)
185201

186-
# x = compute_2D_features(state, env.instance)
187202
x = if two_dimensional_features
188203
compute_2D_features(state, env.instance)
189204
else
@@ -195,17 +210,3 @@ function anticipative_solver(
195210

196211
return obj, dataset
197212
end
198-
199-
# @kwdef struct AnticipativeSolver
200-
# is_2D::Bool = false
201-
# end
202-
203-
# function (solver::AnticipativeSolver)(env::DVSPEnv, scenario=env.scenario; reset_env=false)
204-
# return generate_anticipative_decision(
205-
# env,
206-
# scenario;
207-
# model_builder=highs_model,
208-
# reset_env,
209-
# two_dimensional_features=solver.is_2D,
210-
# )
211-
# end

src/DynamicVehicleScheduling/environment/environment.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ $TYPEDSIGNATURES
3535
3636
Get the current state of the environment.
3737
"""
38-
Utils.observe(env::DVSPEnv) = nothing, env.state
38+
function Utils.observe(env::DVSPEnv)
39+
if env.instance.two_dimensional_features
40+
return compute_2D_features(env.state, env.instance), env.state
41+
end
42+
# else
43+
return compute_features(env.state, env.instance), env.state
44+
end
3945

4046
current_epoch(env::DVSPEnv) = current_epoch(env.state)
4147

src/DynamicVehicleScheduling/environment/instance.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ Instance data structure for the dynamic vehicle scheduling problem.
1414
epoch_duration::T = 1.0
1515
"last epoch index"
1616
last_epoch::Int
17+
"whether to use two-dimensional features"
18+
two_dimensional_features::Bool = false
1719
end
1820

1921
function Instance(
2022
static_instance::StaticInstance;
2123
max_requests_per_epoch::Int=10,
2224
Δ_dispatch::Float64=1.0,
2325
epoch_duration::Float64=1.0,
26+
two_dimensional_features::Bool=false,
2427
)
2528
last_epoch = trunc(
2629
Int,
@@ -35,6 +38,7 @@ function Instance(
3538
Δ_dispatch=Δ_dispatch,
3639
epoch_duration=epoch_duration,
3740
last_epoch=last_epoch,
41+
two_dimensional_features=two_dimensional_features,
3842
)
3943
end
4044

src/DynamicVehicleScheduling/environment/state.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@ State data structure for the Dynamic Vehicle Scheduling Problem.
1616
is_postponable::BitVector = falses(0)
1717
end
1818

19+
function Base.show(io::IO, state::DVSPState)
20+
return print(
21+
io,
22+
"DVSPState(",
23+
"current_epoch=",
24+
state.current_epoch,
25+
", ",
26+
"location_indices=",
27+
state.location_indices,
28+
", ",
29+
"is_must_dispatch=",
30+
state.is_must_dispatch,
31+
", ",
32+
"is_postponable=",
33+
state.is_postponable,
34+
")",
35+
)
36+
end
37+
1938
function reset_state!(
2039
state::DVSPState, instance::Instance; indices, service_time, start_time
2140
)
@@ -189,9 +208,14 @@ function add_new_customers!(
189208
epoch_duration = instance.epoch_duration
190209
Δ_dispatch = instance.Δ_dispatch
191210
planning_start_time = (state.current_epoch - 1) * epoch_duration + Δ_dispatch
192-
is_must_dispatch[2:end] .=
193-
planning_start_time .+ epoch_duration .+ @view(updated_duration[1, 2:end]) .>
194-
updated_start_time[2:end]
211+
if state.current_epoch == last_epoch(instance)
212+
# If we are in the last epoch, all requests must be dispatched
213+
is_must_dispatch[2:end] .= true
214+
else
215+
is_must_dispatch[2:end] .=
216+
planning_start_time .+ epoch_duration .+ @view(updated_duration[1, 2:end]) .>
217+
updated_start_time[2:end]
218+
end
195219
is_postponable[2:end] .= .!is_must_dispatch[2:end]
196220

197221
state.is_must_dispatch = is_must_dispatch

src/DynamicVehicleScheduling/policy.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ function greedy_policy(env::DVSPEnv; model_builder=highs_model)
44
nb_postponable_requests = sum(is_postponable)
55
θ = ones(nb_postponable_requests) * 1e9
66
routes = prize_collecting_vsp(θ; instance=state, model_builder)
7+
@assert is_feasible(state, routes)
78
return routes
89
end
910

@@ -12,6 +13,7 @@ function lazy_policy(env::DVSPEnv; model_builder=highs_model)
1213
nb_postponable_requests = sum(state.is_postponable)
1314
θ = ones(nb_postponable_requests) * -1e9
1415
routes = prize_collecting_vsp(θ; instance=state, model_builder)
16+
@assert is_feasible(state, routes)
1517
return routes
1618
end
1719

src/Utils/data_sample.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ $TYPEDFIELDS
2222
instance::I = nothing
2323
end
2424

25+
function Base.show(io::IO, d::DataSample)
26+
fields = String[]
27+
if !isnothing(d.x)
28+
push!(fields, "x=$(d.x)")
29+
end
30+
if !isnothing(d.θ_true)
31+
push!(fields, "θ_true=$(d.θ_true)")
32+
end
33+
if !isnothing(d.y_true)
34+
push!(fields, "y_true=$(d.y_true)")
35+
end
36+
if !isnothing(d.instance)
37+
push!(fields, "instance=$(d.instance)")
38+
end
39+
return print(io, "DataSample(", join(fields, ", "), ")")
40+
end
41+
2542
"""
2643
$TYPEDSIGNATURES
2744

src/Utils/policy.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ function run_policy!(policy, env::AbstractEnvironment; kwargs...)
3939
y = policy(env; kwargs...)
4040
features, state = observe(env)
4141
if @isdefined labeled_dataset
42-
push!(labeled_dataset, DataSample(; x=features, y_true=y, instance=state))
42+
push!(
43+
labeled_dataset,
44+
DataSample(; x=features, y_true=y, instance=deepcopy(state)),
45+
)
4346
else
44-
labeled_dataset = [DataSample(; x=features, y_true=y, instance=state)]
47+
labeled_dataset = [DataSample(; x=features, y_true=y, instance=deepcopy(state))]
4548
end
4649
reward = step!(env, y)
4750
total_reward += reward

0 commit comments

Comments
 (0)