Skip to content

Commit 4eeda07

Browse files
committed
update
1 parent 0ae5737 commit 4eeda07

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

scripts/main.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ using MLUtils
44
using Statistics
55
using Plots
66

7-
res = fyl_train_model(ArgmaxBenchmark(); epochs=10_000)
8-
plot(res.validation_loss[100:end]; label="Validation Loss")
9-
plot!(res.training_loss[100:end]; label="Training Loss")
7+
res = fyl_train_model(StochasticVehicleSchedulingBenchmark(); epochs=100)
8+
plot(res.validation_loss; label="Validation Loss")
9+
plot!(res.training_loss; label="Training Loss")
1010

1111
baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
1212
DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))

src/dagger.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ end
6767

6868
function DAgger_train_model(b::AbstractStochasticBenchmark{true}; kwargs...)
6969
dataset = generate_dataset(b, 30)
70-
train_instances, validation_instances, test_instances = dataset[1:10],
71-
dataset[11:20],
72-
dataset[21:30]
70+
train_instances, validation_instances, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
7371
train_environments = generate_environments(b, train_instances; seed=0)
7472
validation_environments = generate_environments(b, validation_instances)
7573
model = generate_statistical_model(b)

src/fyl.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# TODO: Implement validation loss as a metric callback
44
# TODO: batch training option
55
# TODO: parallelize loss computation on validation set
6+
# TODO: have supervised learning training method, where fyl_train calls it, therefore we can easily test new supervised losses if needed
7+
# TODO: easier way to define and provide metrics
68

79
function fyl_train_model!(
810
model,
@@ -13,7 +15,7 @@ function fyl_train_model!(
1315
maximizer_kwargs=(sample -> (; instance=sample.info)),
1416
metrics_callbacks::NamedTuple=NamedTuple(),
1517
)
16-
perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=1.0, threaded=true, seed=0)
18+
perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=0.0, threaded=true, seed=0)
1719
loss = FenchelYoungLoss(perturbed)
1820

1921
optimizer = Adam()
@@ -86,7 +88,7 @@ function fyl_train_model!(
8688
end
8789

8890
function fyl_train_model(b::AbstractBenchmark; kwargs...)
89-
dataset = generate_dataset(b, 100)
91+
dataset = generate_dataset(b, 20)
9092
train_dataset, validation_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
9193
model = generate_statistical_model(b)
9294
maximizer = generate_maximizer(b)

0 commit comments

Comments
 (0)