Commit 400af8e

First draft at implementing generic SIL and generic DAgger
1 parent 42ed250 commit 400af8e

14 files changed, +534 -7 lines changed


.gitignore
Lines changed: 3 additions & 0 deletions

@@ -4,3 +4,6 @@
 /Manifest*.toml
 /docs/Manifest*.toml
 /docs/build/
+tensorboard_logs
+.vscode
+Manifest.toml

Project.toml
Lines changed: 14 additions & 1 deletion

@@ -1,9 +1,22 @@
 name = "DecisionFocusedLearningAlgorithms"
 uuid = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
 authors = ["Members of JuliaDecisionFocusedLearning and contributors"]
-version = "1.0.0-DEV"
+version = "0.0.1"
+
+[deps]
+DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 
 [compat]
+Flux = "0.16.5"
+InferOpt = "0.7.1"
+MLUtils = "0.4.8"
+ProgressMeter = "1.11.0"
+UnicodePlots = "3.8.1"
 julia = "1.11"
 
 [extras]

docs/Project.toml
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 [deps]
 DecisionFocusedLearningAlgorithms = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

docs/make.jl
Lines changed: 18 additions & 4 deletions

@@ -1,7 +1,23 @@
 using DecisionFocusedLearningAlgorithms
 using Documenter
 
-DocMeta.setdocmeta!(DecisionFocusedLearningAlgorithms, :DocTestSetup, :(using DecisionFocusedLearningAlgorithms); recursive=true)
+DocMeta.setdocmeta!(
+    DecisionFocusedLearningAlgorithms,
+    :DocTestSetup,
+    :(using DecisionFocusedLearningAlgorithms);
+    recursive=true,
+)
+
+tutorial_dir = joinpath(@__DIR__, "src", "tutorials")
+
+include_tutorial = true
+
+if include_tutorial
+    for file in tutorial_files
+        filepath = joinpath(tutorial_dir, file)
+        Literate.markdown(filepath, md_dir; documenter=true, execute=false)
+    end
+end
 
 makedocs(;
     modules=[DecisionFocusedLearningAlgorithms],
@@ -12,9 +28,7 @@ makedocs(;
         edit_link="main",
         assets=String[],
     ),
-    pages=[
-        "Home" => "index.md",
-    ],
+    pages=["Home" => "index.md", "Tutorials" => include_tutorial ? md_tutorial_files : []],
 )
 
 deploydocs(;
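As committed, this hunk references tutorial_files, md_dir, and md_tutorial_files without defining them, and Literate is never imported, so make.jl cannot run yet. A minimal sketch of the missing pieces, assuming the tutorial scripts live in docs/src/tutorials and the generated markdown is written next to them (all names and paths below are assumptions, not part of this commit):

using Literate

tutorial_dir = joinpath(@__DIR__, "src", "tutorials")
md_dir = tutorial_dir  # assumed: write generated markdown next to the .jl sources
tutorial_files = filter(endswith(".jl"), readdir(tutorial_dir))  # e.g. ["tutorial.jl"]
md_tutorial_files = [joinpath("tutorials", replace(f, ".jl" => ".md")) for f in tutorial_files]

With definitions like these, the "Tutorials" => md_tutorial_files entry in pages points at the markdown files that Literate.markdown emits.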

docs/src/tutorials/tutorial.jl
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# Tutorial
+using DecisionFocusedLearningAlgorithms
+using DecisionFocusedLearningBenchmarks
+using MLUtils: splitobs
+using Plots
+
+b = ArgmaxBenchmark()
+dataset = generate_dataset(b, 100)
+train_instances, validation_instances, test_instances = splitobs(
+    dataset; at=(0.3, 0.3, 0.4)
+)
+
+model = generate_statistical_model(b; seed=0)
+maximizer = generate_maximizer(b)
+
+compute_gap(b, test_instances, model, maximizer)
+
+metrics_callbacks = (;
+    :time => (model, maximizer, epoch) -> (epoch_time = time()),
+    :gap => (;
+        :val =>
+            (model, maximizer, epoch) ->
+                (gap = compute_gap(b, validation_instances, model, maximizer)),
+        :test =>
+            (model, maximizer, epoch) ->
+                (gap = compute_gap(b, test_instances, model, maximizer)),
+    ),
+)
+
+fyl_model = deepcopy(model)
+log = fyl_train_model!(
+    fyl_model,
+    maximizer,
+    train_instances,
+    validation_instances;
+    epochs=100,
+    metrics_callbacks,
+)
+
+log[:gap]
+plot(
+    [log[:gap].val, log[:gap].test];
+    labels=["Val Gap" "Test Gap"],
+    xlabel="Epoch",
+    ylabel="Gap",
+)
+plot(log[:validation_loss])
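The accesses log[:gap].val, log[:gap].test, and log[:validation_loss] suggest that fyl_train_model! calls every callback in the (possibly nested) metrics_callbacks tuple once per epoch and collects the results under matching keys; that behaviour is inferred from this usage, not shown in the diff. A purely illustrative named tuple with the implied shape:

# Hypothetical shape of `log`, inferred from the tutorial rather than the package source.
log_example = (;
    validation_loss=zeros(100),                  # one value per epoch
    time=zeros(100),
    gap=(; val=zeros(100), test=zeros(100)),
)
log_example[:gap].val                            # same access pattern as in the tutorial

Note also that the tutorial loads Plots, DecisionFocusedLearningBenchmarks, and MLUtils, none of which are declared in docs/Project.toml as of this commit.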

scripts/Project.toml
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+[deps]
+DecisionFocusedLearningAlgorithms = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
+DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"

scripts/main.jl
Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+using DecisionFocusedLearningAlgorithms
+using DecisionFocusedLearningBenchmarks
+using MLUtils
+using Statistics
+
+struct KleopatraPolicy{M}
+    model::M
+end
+
+function (m::KleopatraPolicy)(env)
+    x, instance = observe(env)
+    θ = m.model(x)
+    return maximizer(θ; instance)
+end
+
+fyl_train_model(ArgmaxBenchmark(); epochs=1000)
+baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
+DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
+
+b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
+dataset = generate_dataset(b, 100)
+train_instances, validation_instances, test_instances = splitobs(
+    dataset; at=(0.3, 0.3, 0.4)
+)
+train_environments = generate_environments(b, train_instances; seed=0)
+validation_environments = generate_environments(b, validation_instances)
+test_environments = generate_environments(b, test_instances)
+
+train_dataset = vcat(map(train_environments) do env
+    v, y = generate_anticipative_solution(b, env; reset_env=true)
+    return y
+end...)
+
+val_dataset = vcat(map(validation_environments) do env
+    v, y = generate_anticipative_solution(b, env; reset_env=true)
+    return y
+end...)
+
+model = generate_statistical_model(b; seed=0)
+maximizer = generate_maximizer(b)
+anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
+
+fyl_model = deepcopy(model)
+fyl_policy = Policy("fyl", "", KleopatraPolicy(fyl_model))
+
+metrics_callbacks = (;
+    obj=(model, maximizer, epoch) ->
+        mean(evaluate_policy!(fyl_policy, test_environments, 1)[1])
+)
+
+fyl_loss = fyl_train_model!(
+    fyl_model, maximizer, train_dataset, val_dataset; epochs=100, metrics_callbacks
+)
+
+dagger_model = deepcopy(model)
+dagger_policy = Policy("dagger", "", KleopatraPolicy(dagger_model))
+metrics_callbacks = (;
+    obj=(model, maximizer, epoch) ->
+        mean(evaluate_policy!(dagger_policy, test_environments, 1)[1])
+)
+dagger_loss = DAgger_train_model!(
+    dagger_model,
+    maximizer,
+    train_environments,
+    validation_environments,
+    anticipative_policy;
+    iterations=10,
+    fyl_epochs=10,
+    metrics_callbacks,
+)
+
+plot(
+    0:100,
+    [fyl_loss.obj[1:end], dagger_loss.obj[1:end]];
+    labels=["FYL" "DAgger"],
+    xlabel="Epoch",
+    ylabel="Test Average Reward (1 scenario)",
+)
+
+using Statistics
+v_fyl, _ = evaluate_policy!(fyl_policy, test_environments, 100)
+v_dagger, _ = evaluate_policy!(dagger_policy, test_environments, 100)
+mean(v_fyl)
+mean(v_dagger)
+
+anticipative_policy(test_environments[1]; reset_env=true)
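As written, the script calls plot(...) without loading Plots (it is only declared in scripts/Project.toml), imports Statistics twice, and the KleopatraPolicy functor reads maximizer from global scope instead of carrying it. A self-contained variant of the policy, assuming the same observe interface used elsewhere in the script; the name KleopatraPolicy2 is hypothetical, not part of the commit:

using Plots  # needed for the plot(...) calls in the script

# Variant that stores its maximizer instead of relying on the global `maximizer`.
struct KleopatraPolicy2{M,F}
    model::M
    maximizer::F
end

function (p::KleopatraPolicy2)(env)
    x, instance = observe(env)  # same observation interface as in the script
    θ = p.model(x)
    return p.maximizer(θ; instance)
end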

scripts/tb.jl
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+using TensorBoardLogger, Logging, Random
+
+lg = TBLogger("tensorboard_logs/run"; min_level=Logging.Info)
+
+struct sample_struct
+    first_field
+    other_field
+end
+
+with_logger(lg) do
+    for i in 1:100
+        x0 = 0.5 + i / 30
+        s0 = 0.5 / (i / 20)
+        edges = collect(-5:0.1:5)
+        centers = collect(edges[1:(end - 1)] .+ 0.05)
+        histvals = [exp(-((c - x0) / s0)^2) for c in centers]
+        data_tuple = (edges, histvals)
+        data_struct = sample_struct(i^2, i^1.5 - 0.3 * i)
+
+        @info "test" i = i j = i^2 dd = rand(10) .+ 0.1 * i hh = data_tuple
+        @info "test_2" i = i j = 2^i hh = data_tuple log_step_increment = 0
+        @info "" my_weird_struct = data_struct log_step_increment = 0
+        @debug "debug_msg" this_wont_show_up = i
+    end
+end
+
+Dict(:loss => (s, i) -> s + i, :accuracy => (s, i) -> s - i)
src/DecisionFocusedLearningAlgorithms.jl
Lines changed: 14 additions & 1 deletion

@@ -1,5 +1,18 @@
 module DecisionFocusedLearningAlgorithms
 
-# Write your package code here.
+using DecisionFocusedLearningBenchmarks
+const DVSP = DecisionFocusedLearningBenchmarks.DynamicVehicleScheduling
+using Flux: Flux, Adam
+using InferOpt: InferOpt, FenchelYoungLoss, PerturbedAdditive
+using MLUtils: splitobs
+using ProgressMeter: @showprogress
+using UnicodePlots: lineplot
+
+include("utils/metrics.jl")
+include("fyl.jl")
+include("dagger.jl")
+
+export fyl_train_model!,
+    fyl_train_model, baty_train_model, DAgger_train_model!, DAgger_train_model
 
 end

src/dagger.jl
Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+
+function DAgger_train_model!(
+    model,
+    maximizer,
+    train_environments,
+    validation_environments,
+    anticipative_policy;
+    iterations=5,
+    fyl_epochs=3,
+    metrics_callbacks::NamedTuple=NamedTuple(),
+)
+    α = 1.0
+    train_dataset = vcat(map(train_environments) do env
+        v, y = anticipative_policy(env; reset_env=true)
+        return y
+    end...)
+    val_dataset = vcat(map(validation_environments) do env
+        v, y = anticipative_policy(env; reset_env=true)
+        return y
+    end...)
+
+    dataset = deepcopy(train_dataset)
+    all_metrics = []
+    for iter in 1:iterations
+        println("DAgger iteration $iter")
+        metrics = fyl_train_model!(
+            model,
+            maximizer,
+            dataset,
+            val_dataset;
+            epochs=fyl_epochs,
+            metrics_callbacks=metrics_callbacks,
+        )
+        push!(all_metrics, metrics)
+        new_samples = eltype(dataset)[]
+        # Dataset update
+        for env in train_environments
+            reset!(env; reset_rng=false)
+            while !is_terminated(env)
+                x_before = copy(observe(env)[1])
+                _, anticipative_solution = anticipative_policy(env; reset_env=false)
+                p = rand()
+                target = anticipative_solution[1]
+                x, state = observe(env)
+                if size(target.x) != size(x)
+                    @error "Mismatch between expert and observed state" size(target.x) size(
+                        x
+                    )
+                end
+                push!(new_samples, target)
+                if p < α
+                    action = target.y_true
+                else
+                    x, state = observe(env)
+                    θ = model(x)
+                    action = maximizer(θ; instance=state) # ! not benchmark generic
+                end
+                step!(env, action)
+            end
+        end
+        dataset = new_samples # TODO: replay buffer
+        α *= 0.9 # Decay factor for mixing expert and learned policy
+    end
+
+    return _flatten_dagger_metrics(all_metrics)
+end
+
+function DAgger_train_model(b::AbstractStochasticBenchmark{true}; kwargs...)
+    dataset = generate_dataset(b, 30)
+    train_instances, validation_instances, test_instances = dataset[1:10],
+    dataset[11:20],
+    dataset[21:30]
+    train_environments = generate_environments(b, train_instances; seed=0)
+    validation_environments = generate_environments(b, validation_instances)
+    model = generate_statistical_model(b)
+    maximizer = generate_maximizer(b)
+    anticipative_policy =
+        (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)
+    return DAgger_train_model!(
+        model,
+        maximizer,
+        train_environments,
+        validation_environments,
+        anticipative_policy;
+        kwargs...,
+    )
+end
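Two details worth flagging in this first draft: x_before is computed but never used, and dataset = new_samples replaces the training set at every DAgger iteration instead of aggregating it as in the original DAgger formulation (the "# TODO: replay buffer" comment suggests this is known). A one-line aggregating alternative, purely illustrative:

# Aggregating variant, closer to standard DAgger: keep every expert-labelled
# sample collected so far instead of overwriting the dataset each iteration.
append!(dataset, new_samples)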
