Skip to content

Commit 987fec8

Browse files
committed
Implement policy generator for DVSP, and cleanup seed handling
1 parent e8be496 commit 987fec8

File tree

6 files changed

+88
-47
lines changed

6 files changed

+88
-47
lines changed

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ include("algorithms/anticipative_solver.jl")
3838
include("learning/features.jl")
3939
include("learning/2d_features.jl")
4040

41+
include("policy.jl")
4142
# include("policy/abstract_vsp_policy.jl")
4243
# include("policy/greedy_policy.jl")
4344
# include("policy/lazy_policy.jl")
@@ -104,6 +105,20 @@ function Utils.generate_anticipative_solution(
104105
)
105106
end
106107

108+
function Utils.generate_policies(b::DynamicVehicleSchedulingBenchmark)
109+
lazy = Policy(
110+
"Lazy",
111+
"Lazy policy that dispatches vehicles only when they are ready.",
112+
lazy_policy,
113+
)
114+
greedy = Policy(
115+
"Greedy",
116+
"Greedy policy that dispatches vehicles to the nearest customer.",
117+
greedy_policy,
118+
)
119+
return (lazy, greedy)
120+
end
121+
107122
export DynamicVehicleSchedulingBenchmark
108123

109124
end

src/DynamicVehicleScheduling/environment/environment.jl

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,35 @@
1-
struct DVSPEnv{S<:DVSPState} <: Utils.AbstractEnvironment
1+
mutable struct DVSPEnv{S<:DVSPState,R<:AbstractRNG,SS} <: Utils.AbstractEnvironment
22
"associated instance"
33
instance::Instance
44
"current state"
55
state::S
66
"scenario the environment will use when not given a specific one"
77
scenario::Scenario
8+
"random number generator"
9+
rng::R
10+
"seed for the environment"
11+
seed::SS
812
end
913

1014
"""
1115
$TYPEDSIGNATURES
1216
1317
Constructor for [`DVSPEnv`](@ref).
1418
"""
15-
function DVSPEnv(instance::Instance; seed=nothing, rng=MersenneTwister(seed))
16-
scenario = Utils.generate_scenario(instance; rng, seed)
19+
function DVSPEnv(instance::Instance; seed=nothing)
20+
rng = MersenneTwister(seed)
21+
scenario = Utils.generate_scenario(instance; rng)
1722
initial_state = DVSPState(instance; scenario[1]...)
18-
return DVSPEnv(instance, initial_state, scenario)
23+
return DVSPEnv(instance, initial_state, scenario, rng, seed)
1924
end
2025

2126
currrent_epoch(env::DVSPEnv) = current_epoch(env.state)
2227
epoch_duration(env::DVSPEnv) = epoch_duration(env.instance)
2328
last_epoch(env::DVSPEnv) = last_epoch(env.instance)
2429
Δ_dispatch(env::DVSPEnv) = Δ_dispatch(env.instance)
2530

31+
Utils.get_seed(env::DVSPEnv) = env.seed
32+
2633
"""
2734
$TYPEDSIGNATURES
2835
@@ -59,13 +66,19 @@ $TYPEDSIGNATURES
5966
Reset the environment to its initial state.
6067
Also reset the seed if `reset_seed` is set to true.
6168
"""
62-
function Utils.reset!(env::DVSPEnv, scenario=env.scenario)
63-
reset_state!(env.state, env.instance; scenario[1]...)
69+
function Utils.reset!(env::DVSPEnv; seed=get_seed(env), reset_seed=false)
70+
if reset_seed
71+
Random.seed!(env.rng, seed)
72+
end
73+
env.scenario = Utils.generate_scenario(env; rng=env.rng)
74+
reset_state!(env.state, env.instance; env.scenario[1]...)
6475
return nothing
6576
end
6677

6778
"""
68-
remove dispatched customers, advance time, and add new requests to the environment.
79+
$TYPEDSIGNATURES
80+
81+
Remove dispatched customers, advance time, and add new requests to the environment.
6982
"""
7083
function Utils.step!(env::DVSPEnv, routes, scenario=env.scenario)
7184
reward = -apply_routes!(env.state, routes)
Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
1+
function greedy_policy(env::DVSPEnv; model_builder=highs_model)
2+
_, state = observe(env)
3+
(; is_postponable) = state
4+
nb_postponable_requests = sum(is_postponable)
5+
θ = ones(nb_postponable_requests) * 1e9
6+
routes = prize_collecting_vsp(θ; instance=state, model_builder)
7+
return routes
8+
end
9+
10+
function lazy_policy(env::DVSPEnv; model_builder=highs_model)
11+
_, state = observe(env)
12+
nb_postponable_requests = sum(state.is_postponable)
13+
θ = ones(nb_postponable_requests) * -1e9
14+
routes = prize_collecting_vsp(θ; instance=state, model_builder)
15+
return routes
16+
end
17+
18+
"""
19+
$TYPEDEF
20+
21+
Kleopatra policy for the Dynamic Vehicle Scheduling Problem.
22+
"""
23+
struct KleopatraVSPPolicy{P}
24+
prize_predictor::P
25+
end
26+
27+
function (π::KleopatraVSPPolicy)(env::DVSPEnv; model_builder=highs_model)
28+
x, state = observe(env)
29+
(; prize_predictor) = π
30+
# x = has_2D_features ? compute_2D_features(env) : compute_features(env)
31+
θ = prize_predictor(x)
32+
routes = prize_collecting_vsp(θ; instance=state, model_builder)
33+
return routes
34+
end

src/DynamicVehicleScheduling/policy/greedy_policy.jl

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,3 +14,12 @@ function (π::GreedyVSPPolicy)(env::DVSPEnv; model_builder=highs_model)
1414
routes = prize_collecting_vsp(θ; instance=state, model_builder)
1515
return routes
1616
end
17+
18+
function greedy_policy(env::DVSPEnv; model_builder=highs_model)
19+
_, state = observe(env)
20+
(; is_postponable) = state
21+
nb_postponable_requests = sum(is_postponable)
22+
θ = ones(nb_postponable_requests) * 1e9
23+
routes = prize_collecting_vsp(θ; instance=state, model_builder)
24+
return routes
25+
end
Lines changed: 0 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +0,0 @@
1-
"""
2-
$TYPEDEF
3-
4-
Kleopatra policy for the Dynamic Vehicle Scheduling Problem.
5-
"""
6-
struct KleopatraVSPPolicy{P} <: AbstractDynamicVSPPolicy
7-
prize_predictor::P
8-
has_2D_features::Bool
9-
end
10-
11-
"""
12-
$TYPEDSIGNATURES
13-
14-
Custom constructor for [`KleopatraVSPPolicy`](@ref).
15-
"""
16-
function KleopatraVSPPolicy(prize_predictor; has_2D_features=nothing)
17-
has_2D_features = if isnothing(has_2D_features)
18-
size(prize_predictor[1].weight, 2) == 2
19-
else
20-
has_2D_features
21-
end
22-
return KleopatraVSPPolicy(prize_predictor, has_2D_features)
23-
end
24-
25-
function (π::KleopatraVSPPolicy)(env::DVSPEnv; model_builder=highs_model)
26-
state = observe(env)
27-
(; prize_predictor, has_2D_features) = π
28-
x = has_2D_features ? compute_2D_features(env) : compute_features(env)
29-
θ = prize_predictor(x)
30-
routes = prize_collecting_vsp(θ; instance=state, model_builder)
31-
return routes
32-
end

src/Utils/policy.jl

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,12 @@ $TYPEDSIGNATURES
3131
Run the policy on the environment and return the total reward and a dataset of observations.
3232
By default, the environment is reset before running the policy.
3333
"""
34-
function run_policy!(policy, env::AbstractEnvironment)
34+
function run_policy!(policy, env::AbstractEnvironment; kwargs...)
3535
total_reward = 0.0
3636
reset!(env; reset_seed=false)
3737
local labeled_dataset
3838
while !is_terminated(env)
39-
y = policy(env)
39+
y = policy(env; kwargs...)
4040
features, state = observe(env)
4141
if @isdefined labeled_dataset
4242
push!(labeled_dataset, DataSample(; x=features, y_true=y, instance=state))
@@ -49,33 +49,35 @@ function run_policy!(policy, env::AbstractEnvironment)
4949
return total_reward, labeled_dataset
5050
end
5151

52-
function run_policy!(policy, envs::Vector{<:AbstractEnvironment})
52+
function run_policy!(policy, envs::Vector{<:AbstractEnvironment}; kwargs...)
5353
E = length(envs)
5454
rewards = zeros(Float64, E)
5555
datasets = map(1:E) do e
56-
reward, dataset = run_policy!(policy, envs[e])
56+
reward, dataset = run_policy!(policy, envs[e]; kwargs...)
5757
rewards[e] = reward
5858
return dataset
5959
end
6060
return rewards, vcat(datasets...)
6161
end
6262

63-
function run_policy!(policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env))
63+
function run_policy!(
64+
policy, env::AbstractEnvironment, episodes::Int; seed=get_seed(env), kwargs...
65+
)
6466
reset!(env; reset_seed=true, seed)
6567
total_reward = 0.0
6668
datasets = map(1:episodes) do _i
67-
reward, dataset = run_policy!(policy, env)
69+
reward, dataset = run_policy!(policy, env; kwargs...)
6870
total_reward += reward
6971
return dataset
7072
end
7173
return total_reward / episodes, vcat(datasets...)
7274
end
7375

74-
function run_policy!(policy, envs::Vector{<:AbstractEnvironment}, episodes::Int)
76+
function run_policy!(policy, envs::Vector{<:AbstractEnvironment}, episodes::Int; kwargs...)
7577
E = length(envs)
7678
rewards = zeros(Float64, E)
7779
datasets = map(1:E) do e
78-
reward, dataset = run_policy!(policy, envs[e], episodes)
80+
reward, dataset = run_policy!(policy, envs[e], episodes; kwargs...)
7981
rewards[e] = reward
8082
return dataset
8183
end

0 commit comments

Comments (0)