# Train and compare FYL and DAgger policies on decision-focused learning benchmarks.
using DecisionFocusedLearningAlgorithms
using DecisionFocusedLearningBenchmarks
using MLUtils
using Plots  # assuming Plots.jl provides the `plot` call used below
using Statistics

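# A policy that maps an environment observation to a decision: the statistical
# model predicts objective parameters θ, and the combinatorial `maximizer`
# (a global defined further below) turns θ into a feasible solution.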
struct KleopatraPolicy{M}
    model::M
end

function (m::KleopatraPolicy)(env)
    x, instance = observe(env)
    θ = m.model(x)
    return maximizer(θ; instance)
end

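# Quick end-to-end runs of the exported training entry points.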
fyl_train_model(ArgmaxBenchmark(); epochs=1000)
baty_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))
DAgger_train_model(DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false))

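# Benchmark setup: generate instances, split them, and build train/validation/test environments.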
b = DynamicVehicleSchedulingBenchmark(; two_dimensional_features=false)
dataset = generate_dataset(b, 100)
train_instances, validation_instances, test_instances = splitobs(
    dataset; at=(0.3, 0.3, 0.4)
)
train_environments = generate_environments(b, train_instances; seed=0)
validation_environments = generate_environments(b, validation_instances)
test_environments = generate_environments(b, test_instances)

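# Imitation datasets: solve each environment with the anticipative
# (perfect-information) policy and keep the resulting target samples.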
train_dataset = vcat(map(train_environments) do env
    v, y = generate_anticipative_solution(b, env; reset_env=true)
    return y
end...)

val_dataset = vcat(map(validation_environments) do env
    v, y = generate_anticipative_solution(b, env; reset_env=true)
    return y
end...)

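# Statistical model, combinatorial maximizer, and the anticipative expert policy.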
model = generate_statistical_model(b; seed=0)
maximizer = generate_maximizer(b)
anticipative_policy = (env; reset_env) -> generate_anticipative_solution(b, env; reset_env)

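# FYL training on the fixed anticipative dataset; the callback records the
# average test reward (over 1 scenario) at each epoch.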
fyl_model = deepcopy(model)
fyl_policy = Policy("fyl", "", KleopatraPolicy(fyl_model))

metrics_callbacks = (;
    obj=(model, maximizer, epoch) ->
        mean(evaluate_policy!(fyl_policy, test_environments, 1)[1])
)

fyl_loss = fyl_train_model!(
    fyl_model, maximizer, train_dataset, val_dataset; epochs=100, metrics_callbacks
)

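# DAgger training: iteratively roll out the current policy on the training
# environments and aggregate data labeled by the anticipative expert.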
dagger_model = deepcopy(model)
dagger_policy = Policy("dagger", "", KleopatraPolicy(dagger_model))
metrics_callbacks = (;
    obj=(model, maximizer, epoch) ->
        mean(evaluate_policy!(dagger_policy, test_environments, 1)[1])
)
dagger_loss = DAgger_train_model!(
    dagger_model,
    maximizer,
    train_environments,
    validation_environments,
    anticipative_policy;
    iterations=10,
    fyl_epochs=10,
    metrics_callbacks,
)

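# Learning curves: test average reward per epoch for the FYL and DAgger runs.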
plot(
    0:100,
    [fyl_loss.obj[1:end], dagger_loss.obj[1:end]];
    labels=["FYL" "DAgger"],
    xlabel="Epoch",
    ylabel="Test Average Reward (1 scenario)",
)

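# Final evaluation on the test environments, using 100 scenarios per environment.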
v_fyl, _ = evaluate_policy!(fyl_policy, test_environments, 100)
v_dagger, _ = evaluate_policy!(dagger_policy, test_environments, 100)
mean(v_fyl)
mean(v_dagger)

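# Anticipative (perfect-information) solution for one test environment, as a reference.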
anticipative_policy(test_environments[1]; reset_env=true)
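# Optional sketch: average the anticipative objective over all test environments
# as an upper-bound reference for the learned policies. This assumes the first
# value returned by `anticipative_policy` is the scalar objective `v`, as in the
# dataset construction above.
v_anticipative = [first(anticipative_policy(env; reset_env=true)) for env in test_environments]
mean(v_anticipative)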