
Commit ab6f9ef

fix doc
1 parent ecb4054 commit ab6f9ef


3 files changed: +18, -23 lines


docs/src/tutorials/warcraft.jl

Lines changed: 10 additions & 11 deletions
@@ -17,23 +17,23 @@ b = WarcraftBenchmark()
 # This method takes as input the benchmark object for which the dataset is to be generated, and a second argument specifying the number of samples to generate:
 dataset = generate_dataset(b, 50)

-# We obtain an [`InferOptDataset`](@ref) object, which contains all needed data for the problem.
+# We obtain a vector of [`DataSample`](@ref) object, which contains all needed data for the problem.
 # Subdatasets can be created through regular slicing:
 train_dataset, test_dataset = dataset[1:45], dataset[46:50]

 # And getting an individual sample will return a NamedTuple with four keys: `features`, `instance`, `costs`, and `solution`:
 sample = test_dataset[1]
 # In the case the the Warcraft benchmark, `features` correspond to the input image:
-x = sample.features
+x = sample.x
 # `costs` correspond to the true unknown terrain weights:
-θ_true = sample.costs
+θ_true = sample.θ
 # `solution` correspond to the true shortest path:
-y_true = sample.solution
+y_true = sample.y
 # `instance` is not used in this benchmark, therefore set to nothing:
 sample.instance

 # For some benchmarks, we provide the following plotting method [`plot_data`](@ref) to visualize the data:
-plot_data(b; sample...)
+plot_data(b, sample)
 # We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.

 # ## Building a pipeline
@@ -50,7 +50,7 @@ maximizer = generate_maximizer(b; dijkstra=true)
 # In the case o fthe Warcraft benchmark, the method has an additioonal keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
 y = maximizer(θ)
 # As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
-plot_data(b; features=x, costs=θ, solution=y)
+plot_data(b, DataSample(; x, θ, y))
 # We can evaluate the current pipeline performance using the optimality gap metric:
 starting_gap = compute_gap(b, test_dataset, model, maximizer)

@@ -61,9 +61,6 @@ using InferOpt
 using Flux
 using Plots

-X_train = train_dataset.features
-Y_train = train_dataset.solutions
-
 perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
 loss = FenchelYoungLoss(perturbed_maximizer)

@@ -73,7 +70,7 @@ opt_state = Flux.setup(Adam(1e-3), model)
 loss_history = Float64[]
 for epoch in 1:50
     val, grads = Flux.withgradient(model) do m
-        sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset)
+        sum(loss(m(sample.x), sample.y) for sample in train_dataset) / length(train_dataset)
     end
     Flux.update!(opt_state, model, grads[1])
     push!(loss_history, val)
@@ -86,4 +83,6 @@ plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
 final_gap = compute_gap(b, test_dataset, model, maximizer)

 #
-plot_data(b; sample...)
+θ = model(x)
+y = maximizer(θ)
+plot_data(b, DataSample(; x, θ, y))
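
Read end to end, the hunks above move the tutorial from splatting named fields (`features`, `costs`, `solution`) to working with `DataSample` objects and their `x`, `θ`, and `y` fields directly. The following condensed sketch shows the resulting pattern; it is an illustration assembled from the diff, not a verbatim excerpt, and it assumes `b`, `model`, `maximizer`, and `loss` are the benchmark, statistical model, maximizer, and `FenchelYoungLoss` built earlier in the tutorial:

dataset = generate_dataset(b, 50)                        # Vector of DataSample objects
train_dataset, test_dataset = dataset[1:45], dataset[46:50]

sample = test_dataset[1]
x, θ_true, y_true = sample.x, sample.θ, sample.y         # features, true weights, true path
plot_data(b, sample)                                     # plotting now takes the sample itself

# Predictions are wrapped in a DataSample before plotting:
θ = model(x)
y = maximizer(θ)
plot_data(b, DataSample(; x, θ, y))

# The training loss iterates over samples instead of zipped feature/solution vectors:
mean_loss(m) = sum(loss(m(s.x), s.y) for s in train_dataset) / length(train_dataset)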

src/Utils/interface.jl

Lines changed: 4 additions & 4 deletions
@@ -14,10 +14,10 @@ The following methods exist for benchmarks:
 abstract type AbstractBenchmark end

 """
-    generate_dataset(::AbstractBenchmark, dataset_size::Int) -> InferOptDataset
+    generate_dataset(::AbstractBenchmark, dataset_size::Int) -> Vector{<:DataSample}

-Generate an [`InferOptDataset`](@ref) for given benchmark as a Vector of length `dataset_size`.
-Content of the dataset can be visualized using [`plot_data`](@ref).
+Generate a `Vector` of [`DataSample`](@ref) of length `dataset_size` for given benchmark.
+Content of the dataset can be visualized using [`plot_data`](@ref), when it applies.
 """
 function generate_dataset end

@@ -46,7 +46,7 @@ Check the specific benchmark documentation of `plot_data` for more details on th
 function plot_data end

 """
-    compute_gap(::AbstractBenchmark, dataset::InferOptDataset, statistical_model, maximizer) -> Float64
+    compute_gap(::AbstractBenchmark, dataset::Vector{<:DataSample}, statistical_model, maximizer) -> Float64

 Compute the average relative optimality gap of the pipeline on the dataset.
 """
src/Warcraft/Warcraft.jl

Lines changed: 4 additions & 8 deletions
@@ -28,18 +28,14 @@ $TYPEDSIGNATURES
 Plot the image `im`, the weights `weights`, and the path `path` on the same Figure.
 """
 function Utils.plot_data(
-    ::WarcraftBenchmark;
-    features,
-    solution,
-    costs,
+    ::WarcraftBenchmark,
+    sample::DataSample,
     θ_title="Weights",
     y_title="Path",
-    θ_true=costs,
+    θ_true=sample.θ,
     kwargs...,
 )
-    x = features
-    y = solution
-    θ = costs
+    (; x, y, θ) = sample
     im = dropdims(x; dims=4)
     img = convert_image_for_plot(im)
     p1 = Plots.plot(
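
The rewritten body relies on Julia's property destructuring syntax to unpack the three fields of the sample in a single line. A self-contained illustration of that syntax, using a hypothetical struct in place of `DataSample`:

struct ToySample
    x::Vector{Float64}
    y::Vector{Int}
    θ::Vector{Float64}
end

s = ToySample([1.0, 2.0], [0, 1], [0.3, 0.7])
(; x, y, θ) = s      # equivalent to x = s.x; y = s.y; θ = s.θ (requires Julia ≥ 1.7)
@show x y θ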
