Commit 67a0fa9

fix tests and cleanup
1 parent 9fe5e86 commit 67a0fa9

11 files changed, with 340 additions and 193 deletions.

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
 Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
+InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
 Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
@@ -40,6 +41,7 @@ Flux = "0.14, 0.15, 0.16"
 Graphs = "1.11"
 HiGHS = "1.9"
 Images = "0.26.1"
+InferOpt = "0.7.0"
 Ipopt = "1.6"
 IterTools = "1.10.0"
 JSON = "0.21.4"

docs/src/warcraft.md

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+```@meta
+EditURL = "tutorials/warcraft.jl"
+```
+
+# Path-finding on image maps
+
+In this tutorial, we showcase the capabilities of DecisionFocusedLearningBenchmarks.jl on one of its main benchmarks: the Warcraft benchmark.
+This benchmark is a simple path-finding problem: the goal is to find the shortest path between the top-left and bottom-right corners of a given image map.
+The map is a 2D image representing a 12x12 grid, where each cell has an unknown travel cost that depends on its terrain type.
+
+First, let's load the package and create a benchmark object as follows:
+
+````@example warcraft
+using DecisionFocusedLearningBenchmarks
+b = WarcraftBenchmark()
+````
+
+## Dataset generation
+
+Benchmark objects act as generators for the elements needed to build an algorithm that tackles the problem.
+First of all, every benchmark can generate datasets using the [`generate_dataset`](@ref) method.
+This method takes the benchmark object as its first argument and the number of samples to generate as its second:
+
+````@example warcraft
+dataset = generate_dataset(b, 50);
+nothing #hide
+````
+
+We obtain a vector of [`DataSample`](@ref) objects, containing all the data needed for the problem.
+Subdatasets can be created through regular slicing:
+
+````@example warcraft
+train_dataset, test_dataset = dataset[1:45], dataset[46:50]
+````
+
+Indexing an individual sample returns a [`DataSample`](@ref) with four fields: `x`, `instance`, `θ_true`, and `y_true`:
+
+````@example warcraft
+sample = test_dataset[1]
+````
+
+`x` corresponds to the input features, i.e. the input image (a 3D array) in the case of the Warcraft benchmark:
+
+````@example warcraft
+x = sample.x
+````
+
+`θ_true` corresponds to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:
+
+````@example warcraft
+θ_true = sample.θ_true
+````
+
+`y_true` corresponds to the optimal shortest path, encoded as a binary matrix:
+
+````@example warcraft
+y_true = sample.y_true
+````
+
+`instance` is not used in this benchmark and is therefore set to `nothing`:
+
+````@example warcraft
+isnothing(sample.instance)
+````
+
+For some benchmarks, we provide the [`plot_data`](@ref) method to visualize the data:
+
+````@example warcraft
+plot_data(b, sample)
+````
+
+Here we can see the terrain image, the true terrain weights, and the true shortest path, which avoids the high-cost cells.
+
+## Building a pipeline
+
+DecisionFocusedLearningBenchmarks also provides methods to build a hybrid machine learning and combinatorial optimization pipeline for the benchmark.
+First, the [`generate_statistical_model`](@ref) method generates a machine learning predictor that predicts cell weights from the input image:
+
+````@example warcraft
+model = generate_statistical_model(b)
+````
+
+In the case of the Warcraft benchmark, the model is a convolutional neural network built using the Flux.jl package.
+
+````@example warcraft
+θ = model(x)
+````
+
+Note that the model is not trained yet, and its parameters are randomly initialized.
+
+Finally, the [`generate_maximizer`](@ref) method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:
+
+````@example warcraft
+maximizer = generate_maximizer(b; dijkstra=true)
+````
+
+In the case of the Warcraft benchmark, the method has an additional keyword argument to choose the algorithm to use: Dijkstra's algorithm or the Bellman-Ford algorithm.
+
+````@example warcraft
+y = maximizer(θ)
+````
+
+As we can see, the pipeline currently predicts random noise as cell weights, and therefore the maximizer returns a straight-line path.
+
+````@example warcraft
+plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
+````
+
+We can evaluate the current pipeline performance using the optimality gap metric:
+
+````@example warcraft
+starting_gap = compute_gap(b, test_dataset, model, maximizer)
+````
+
+## Using a learning algorithm
+
+We can now train the model using the InferOpt.jl package:
+
+````@example warcraft
+using InferOpt
+using Flux
+using Plots
+
+perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
+loss = FenchelYoungLoss(perturbed_maximizer)
+
+starting_gap = compute_gap(b, test_dataset, model, maximizer)
+
+opt_state = Flux.setup(Adam(1e-3), model)
+loss_history = Float64[]
+for epoch in 1:50
+    val, grads = Flux.withgradient(model) do m
+        sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset)
+    end
+    Flux.update!(opt_state, model, grads[1])
+    push!(loss_history, val)
+end
+
+plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
+````
+
+````@example warcraft
+final_gap = compute_gap(b, test_dataset, model, maximizer)
+````
+
+````@example warcraft
+θ = model(x)
+y = maximizer(θ)
+plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
+````
+
+---
+
+*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
+
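
Note on the optimality gap used in the tutorial above: the diff does not show how `compute_gap` is defined, so the following is only a minimal sketch of one common definition, assuming the gap is the relative loss in objective value when the predicted path is evaluated under the true weights. The helper `relative_gap` and the 2x2 toy data are hypothetical and are not part of the package.

```julia
using LinearAlgebra: dot

# Hypothetical helper, NOT the package's `compute_gap`.
# θ_true holds the negated cell costs (maximization form, as in the tutorial),
# y_true and y_pred are binary matrices marking the cells on each path.
function relative_gap(θ_true, y_true, y_pred)
    best = dot(θ_true, y_true)       # objective value of the optimal path
    achieved = dot(θ_true, y_pred)   # objective value of the predicted path
    return (best - achieved) / abs(best)
end

# Toy 2x2 map: the cell in row 1, column 2 is expensive (cost 5).
θ_true = -[1.0 5.0; 1.0 1.0]
y_true = [1 0; 1 1]   # goes around the expensive cell, total cost 3
y_pred = [1 1; 0 1]   # crosses the expensive cell, total cost 7

gap = relative_gap(θ_true, y_true, y_pred)
```

Under this definition a gap of 0 means the predicted path is as good as the optimal one, and larger values mean a more expensive path.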

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 28 additions & 11 deletions
@@ -5,11 +5,10 @@ using ..Utils
 using Base: @kwdef
 using CommonRLInterface: CommonRLInterface, AbstractEnv, reset!, terminated, observe, act!
 using DataDeps: @datadep_str
-# using ChainRulesCore
 using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
 using Graphs
 using HiGHS
-# using InferOpt
+using InferOpt: LinearMaximizer
 using IterTools: partition
 using JSON
 using JuMP
@@ -21,8 +20,6 @@ using Statistics: mean, quantile

 include("utils.jl")

-include("abstract_policy.jl")
-
 # static vsp stuff
 include("static_vsp/instance.jl")
 include("static_vsp/parsing.jl")
@@ -41,20 +38,40 @@ include("algorithms/anticipative_solver.jl")

 include("learning/features.jl")
 include("learning/2d_features.jl")
-include("learning/dataset.jl")

 include("policy/abstract_vsp_policy.jl")
 include("policy/greedy_policy.jl")
 include("policy/lazy_policy.jl")
 include("policy/anticipative_policy.jl")
 include("policy/kleopatra_policy.jl")

-struct DVSPBenchmark <: AbstractDynamicBenchmark end
+include("maximizer.jl")
+
+"""
+$TYPEDEF
+
+Abstract type for dynamic vehicle scheduling benchmarks.
+"""
+@kwdef struct DVSPBenchmark <: AbstractDynamicBenchmark
+    max_requests_per_epoch::Int = 10
+    Δ_dispatch::Float64 = 1.0
+    epoch_duration::Float64 = 1.0
+end

-function Utils.generate_sample(b::DVSPBenchmark, rng::AbstractRNG)
-    return DataSample(;
-        instance=Instance(read_vsp_instance(readdir(datadep"dvrptw"; join=true)[1]))
-    )
+function Utils.generate_dataset(b::DVSPBenchmark, dataset_size::Int=1)
+    (; max_requests_per_epoch, Δ_dispatch, epoch_duration) = b
+    files = readdir(datadep"dvrptw"; join=true)
+    dataset_size = min(dataset_size, length(files))
+    return [
+        DataSample(;
+            instance=Instance(
+                read_vsp_instance(files[i]);
+                max_requests_per_epoch,
+                Δ_dispatch,
+                epoch_duration,
+            ),
+        ) for i in 1:dataset_size
+    ]
 end

 function Utils.generate_scenario_generator(::DVSPBenchmark)
@@ -70,7 +87,7 @@ function Utils.generate_environment(::DVSPBenchmark, instance::Instance; kwargs.
 end

 function Utils.generate_maximizer(::DVSPBenchmark)
-    return prize_collecting_vsp
+    return LinearMaximizer(oracle; g, h)
 end

 export DVSPBenchmark #, generate_environment # , generate_sample, generate_anticipative_solver
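
Note on the new `generate_dataset` method: it combines a `@kwdef` struct with default parameters, property destructuring, and a dataset size capped at the number of available instance files. The pattern can be sketched in isolation as follows. This is an illustration under stated assumptions, not the package code: `ToyBenchmark`, `toy_dataset`, and the file names are hypothetical, and named tuples stand in for `DataSample` and `Instance`.

```julia
using Base: @kwdef

# Hypothetical stand-in for DVSPBenchmark, with the same fields and defaults.
@kwdef struct ToyBenchmark
    max_requests_per_epoch::Int = 10
    Δ_dispatch::Float64 = 1.0
    epoch_duration::Float64 = 1.0
end

# Build one sample per file, never more than the number of files available.
function toy_dataset(b::ToyBenchmark, files::Vector{String}, dataset_size::Int=1)
    # Property destructuring (Julia 1.7+) pulls the fields out by name.
    (; max_requests_per_epoch, Δ_dispatch, epoch_duration) = b
    n = min(dataset_size, length(files))
    return [
        (; file=files[i], max_requests_per_epoch, Δ_dispatch, epoch_duration) for i in 1:n
    ]
end

b = ToyBenchmark(; epoch_duration=2.0)  # override one default by keyword
files = ["a.txt", "b.txt", "c.txt"]     # placeholder file names
samples = toy_dataset(b, files, 5)      # asks for 5, only 3 files exist
```

One consequence of the `min` call is that requesting more samples than there are files silently returns a shorter dataset, so callers that rely on an exact size may want to check `length` on the result.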

src/DynamicVehicleScheduling/abstract_policy.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/DynamicVehicleScheduling/environment/environment.jl

Lines changed: 3 additions & 2 deletions
@@ -45,12 +45,13 @@ $TYPEDSIGNATURES
 Get the planning start time of the environment, i.e. the time at which vehicles routes dispatched in current epoch can depart.
 """
 planning_start_time(env::DVSPEnv) = time(env) + Δ_dispatch(env)
+
 """
 $TYPEDSIGNATURES

 Check if the episode is terminated, i.e. if the current epoch is the last one.
 """
-CommonRLInterface.terminated(env::DVSPEnv) = current_epoch(env) >= last_epoch(env)
+CommonRLInterface.terminated(env::DVSPEnv) = current_epoch(env) > last_epoch(env)

 """
 $TYPEDSIGNATURES
@@ -69,7 +70,7 @@ remove dispatched customers, advance time, and add new requests to the environme
 function CommonRLInterface.act!(env::DVSPEnv, routes, scenario=env.scenario)
     reward = -apply_routes!(env.state, routes)
     env.state.current_epoch += 1
-    if current_epoch(env) <= last_epoch(env)
+    if !CommonRLInterface.terminated(env)
         add_new_customers!(env.state, env.instance; scenario[current_epoch(env)]...)
     end
     return reward
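
Note on the `terminated` change from `>=` to `>`: the effect of the boundary can be illustrated with a toy episode loop. This is a minimal sketch, not the package's `DVSPEnv`. `ToyEnv` and `run_episode` are hypothetical, and the sketch assumes that epochs are numbered from 1 and that the driver loop has the form `while !terminated(env)`.

```julia
# Toy stand-in for DVSPEnv (hypothetical): it only tracks epoch counters.
mutable struct ToyEnv
    current_epoch::Int
    last_epoch::Int
end

# Run one episode and record the epochs in which an action was taken.
# `is_done` is the comparison used by the termination check.
function run_episode(last_epoch::Int; is_done=(>))
    env = ToyEnv(1, last_epoch)  # assumption: epochs are numbered from 1
    acted_in = Int[]
    while !is_done(env.current_epoch, env.last_epoch)
        push!(acted_in, env.current_epoch)
        env.current_epoch += 1   # mirrors `env.state.current_epoch += 1` in act!
    end
    return acted_in
end

new_rule = run_episode(3)                # current_epoch > last_epoch
old_rule = run_episode(3; is_done=(>=))  # current_epoch >= last_epoch
```

Under these assumptions the `>` rule lets the loop act in the last epoch as well, while the `>=` rule stops one epoch early. Reusing `terminated` inside `act!` also keeps the guard on `add_new_customers!` consistent with the termination check.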

src/DynamicVehicleScheduling/environment/instance.jl

Lines changed: 0 additions & 8 deletions
@@ -14,8 +14,6 @@ Instance data structure for the dynamic vehicle scheduling problem.
     epoch_duration::T = 1.0
     "last epoch index"
     last_epoch::Int
-    # "seed for customer sampling"
-    # seed::S
 end

 function Instance(
@@ -44,9 +42,3 @@ end
 epoch_duration(instance::Instance) = instance.epoch_duration
 last_epoch(instance::Instance) = instance.last_epoch
 max_requests_per_epoch(instance::Instance) = instance.max_requests_per_epoch
-# static_instance(instance::Instance) = instance.static_instance
-
-# duration(instance::Instance) = duration(instance.static_instance)
-# service_time(instance::Instance) = service_time(instance.static_instance)
-# coordinate(instance::Instance) = coordinate(instance.static_instance)
-# start_time(instance::Instance) = start_time(instance.static_instance)
