Commit 0ae5737

bump to newer version of DFLBenchmarks
1 parent 83e4045


4 files changed (+43, -19 lines)


.JuliaFormatter.toml

Lines changed: 0 additions & 1 deletion

@@ -1,2 +1 @@
-# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options
 style = "blue"

Project.toml

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

 [compat]
+DecisionFocusedLearningBenchmarks = "0.3.0"
 Flux = "0.16.5"
 InferOpt = "0.7.1"
 MLUtils = "0.4.8"
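
The new [compat] entry restricts DecisionFocusedLearningBenchmarks to semver-compatible releases of the 0.3 series (0.3.0 <= v < 0.4.0), matching the version bump in the commit message. As a minimal sketch, on recent Julia versions (1.8+) the same bound can be written from the REPL instead of editing Project.toml by hand, assuming the project environment lives in the current directory:

using Pkg
Pkg.activate(".")                                         # activate this project environment
Pkg.compat("DecisionFocusedLearningBenchmarks", "0.3.0")  # write the [compat] entry
Pkg.resolve()                                             # update the manifest to match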

scripts/main.jl

Lines changed: 8 additions & 4 deletions

@@ -2,6 +2,14 @@ using DecisionFocusedLearningAlgorithms
 using DecisionFocusedLearningBenchmarks
 using MLUtils
 using Statistics
+using Plots
+
+res = fyl_train_model(ArgmaxBenchmark(); epochs=10_000)
+plot(res.validation_loss[100:end]; label="Validation Loss")
+plot!(res.training_loss[100:end]; label="Training Loss")
+
+baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
+DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))

 struct KleopatraPolicy{M}
     model::M
@@ -13,10 +21,6 @@ function (m::KleopatraPolicy)(env)
     return maximizer(θ; instance)
 end

-fyl_train_model(ArgmaxBenchmark(); epochs=1000)
-baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
-DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
-
 b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
 dataset = generate_dataset(b, 100)
 train_instances, validation_instances, test_instances = splitobs(
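
The updated script keeps the result of fyl_train_model and plots both loss curves, dropping the first 99 epochs to cut off the initial transient. A small alternative sketch, assuming res.validation_loss and res.training_loss are plain numeric vectors with one entry per epoch: smooth the full curves with a trailing moving average instead of discarding the early points.

using Plots, Statistics

# Trailing moving average over a window of w points (shorter near the start).
moving_average(v, w) = [mean(v[max(1, i - w + 1):i]) for i in eachindex(v)]

plot(moving_average(res.validation_loss, 50); label="Validation Loss (smoothed)")
plot!(moving_average(res.training_loss, 50); label="Training Loss (smoothed)")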

src/fyl.jl

Lines changed: 34 additions & 14 deletions

@@ -1,53 +1,75 @@
 # TODO: every N epochs
 # TODO: best_model saving method, using default metric validation loss, overwritten in dagger
 # TODO: Implement validation loss as a metric callback
+# TODO: batch training option
+# TODO: parallelize loss computation on validation set

 function fyl_train_model!(
     model,
     maximizer,
     train_dataset::AbstractArray{<:DataSample},
     validation_dataset;
     epochs=100,
-    maximizer_kwargs=(sample -> (; instance=sample.instance)),
+    maximizer_kwargs=(sample -> (; instance=sample.info)),
     metrics_callbacks::NamedTuple=NamedTuple(),
 )
-    perturbed = PerturbedAdditive(maximizer; nb_samples=20, ε=1.0, threaded=true)
+    perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=1.0, threaded=true, seed=0)
     loss = FenchelYoungLoss(perturbed)

     optimizer = Adam()
     opt_state = Flux.setup(optimizer, model)

     total_loss = 0.0
     for sample in validation_dataset
-        (; x, y_true) = sample
-        total_loss += loss(model(x), y_true; maximizer_kwargs(sample)...)
+        (; x, y) = sample
+        total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
     end
     loss_history = [total_loss / length(validation_dataset)]

+    total_train_loss = 0.0
+    for sample in train_dataset
+        (; x, y) = sample
+        total_train_loss += loss(model(x), y; maximizer_kwargs(sample)...)
+    end
+
     # Initialize metrics history with epoch 0 for type stability
     metrics_history = _initialize_nested_metrics(metrics_callbacks, model, maximizer, 0)

     # Add validation loss to metrics
     metrics_history = merge(
-        metrics_history, (; validation_loss=[total_loss / length(validation_dataset)])
+        metrics_history,
+        (;
+            validation_loss=[total_loss / length(validation_dataset)],
+            training_loss=[total_train_loss / length(train_dataset)],
+        ),
     )

     @showprogress for epoch in 1:epochs
+        l = 0
         for sample in train_dataset
-            (; x, y_true) = sample
-            grads = Flux.gradient(model) do m
-                loss(m(x), y_true; maximizer_kwargs(sample)...)
+            (; x, y) = sample
+            val, grads = Flux.withgradient(model) do m
+                loss(m(x), y; maximizer_kwargs(sample)...)
             end
+            l += val
             Flux.update!(opt_state, model, grads[1])
         end
         # Evaluate on validation set
         total_loss = 0.0
         for sample in validation_dataset
-            (; x, y_true) = sample
-            total_loss += loss(model(x), y_true; maximizer_kwargs(sample)...)
+            (; x, y) = sample
+            total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
         end
         push!(loss_history, total_loss / length(validation_dataset))
         push!(metrics_history.validation_loss, total_loss / length(validation_dataset))
+        # push!(metrics_history.training_loss, l / length(train_dataset))
+
+        total_loss = 0.0
+        for sample in train_dataset
+            (; x, y) = sample
+            total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
+        end
+        push!(metrics_history.training_loss, total_loss / length(train_dataset))

         # Call metrics callbacks
         if !isempty(metrics_callbacks)
@@ -64,10 +86,8 @@ function fyl_train_model!(
 end

 function fyl_train_model(b::AbstractBenchmark; kwargs...)
-    dataset = generate_dataset(b, 30)
-    train_dataset, validation_dataset, test_dataset = dataset[2:2],
-        dataset[11:20],
-        dataset[21:30]
+    dataset = generate_dataset(b, 100)
+    train_dataset, validation_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
     model = generate_statistical_model(b)
     maximizer = generate_maximizer(b)
     return fyl_train_model!(model, maximizer, train_dataset, validation_dataset; kwargs...)
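
The core of fyl_train_model! is a Fenchel-Young loss wrapped around an additively perturbed maximizer, differentiated with Flux.withgradient so the per-sample loss value comes for free alongside the gradients. A self-contained toy sketch of that single training step, with a hypothetical one-hot argmax standing in for the benchmark maximizer and a plain Dense layer for the statistical model (not the objects used in this diff):

using Flux, InferOpt

# Hypothetical stand-in maximizer: one-hot encoding of the argmax.
# Extra keyword arguments (e.g. instance=...) are accepted and ignored,
# mirroring how maximizer_kwargs are splatted in above.
function one_hot_argmax(θ; kwargs...)
    e = zeros(length(θ))
    e[argmax(θ)] = 1.0
    return e
end

perturbed = PerturbedAdditive(one_hot_argmax; nb_samples=50, ε=1.0, threaded=true, seed=0)
loss = FenchelYoungLoss(perturbed)

model = Dense(5 => 5)                 # toy statistical model
opt_state = Flux.setup(Adam(), model)

x = randn(Float32, 5)                 # toy feature vector
y = one_hot_argmax(randn(5))          # toy target solution

# One training step, mirroring the inner loop of fyl_train_model!.
val, grads = Flux.withgradient(m -> loss(m(x), y), model)
Flux.update!(opt_state, model, grads[1])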
