
Commit 926291a

changed again how the policy evaluation works + cleanup
1 parent 2d80305 commit 926291a

File tree

- README.md
- docs/src/benchmark_interfaces.md
- src/DynamicVehicleScheduling/anticipative_solver.jl
- src/DynamicVehicleScheduling/plot.jl
- src/Utils/policy.jl
- whale_shark_128786.mp4

6 files changed: +28 −54 lines changed
README.md

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@
 [![Coverage](https://codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningBenchmarks.jl/branch/main/graph/badge.svg)](https://app.codecov.io/gh/JuliaDecisionFocusedLearning/DecisionFocusedLearningBenchmarks.jl)
 [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle)
 
+!!! warning
+    This package is currently under active development. The API may change in future releases.
+    Please refer to the [documentation](https://JuliaDecisionFocusedLearning.github.io/DecisionFocusedLearningBenchmarks.jl/stable/) for the latest updates.
+
 ## What is Decision-Focused Learning?
 
 Decision-focused learning (DFL) is a paradigm that integrates machine learning prediction with combinatorial optimization to make better decisions under uncertainty. Unlike traditional "predict-then-optimize" approaches that optimize prediction accuracy independently of downstream decision quality, DFL directly optimizes end-to-end decision performance.
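
For intuition, here is a minimal sketch of the contrast described above. It is illustrative only, not this package's API; the optimizer and cost below are toy placeholders:

```julia
# Illustrative only: toy one-dimensional example, not this package's API.
# `optimizer` picks the decision minimizing predicted cost over two options.
optimizer(θ̂) = θ̂ > 0 ? :option_a : :option_b
decision_cost(ŷ, θ_true) = ŷ == :option_a ? θ_true : -θ_true

# Predict-then-optimize: fit the prediction to the true parameter alone.
prediction_loss(model, x, θ_true) = abs2(model(x) - θ_true)

# Decision-focused: score the decision induced by the prediction
# against the true parameter, targeting decision quality directly.
function decision_loss(model, x, θ_true)
    θ̂ = model(x)                     # predicted problem parameter
    ŷ = optimizer(θ̂)                 # decision made from the prediction
    return decision_cost(ŷ, θ_true)  # cost of that decision under the truth
end
```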

docs/src/benchmark_interfaces.md

Lines changed: 2 additions & 2 deletions
@@ -26,13 +26,13 @@ The package defines a hierarchy of three abstract types:
 
 ```
 AbstractBenchmark
-── AbstractStochasticBenchmark{exogenous}
+├── AbstractStochasticBenchmark{exogenous}
 └── AbstractDynamicBenchmark{exogenous}
 ```
 
 - **`AbstractBenchmark`**: static, single-stage optimization problems
 - **`AbstractStochasticBenchmark{exogenous}`**: stochastic, single-stage optimization problems
-**`AbstractDynamicBenchmark{exogenous}`**: multi-stage sequential decision problems
+- **`AbstractDynamicBenchmark{exogenous}`**: multi-stage sequential decision-making problems
 
 The `{exogenous}` type parameter indicates whether the uncertainty distribution comes from external sources (`true`) or is influenced by decisions (`false`), which affects the available methods.

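A minimal sketch of how a user-defined benchmark might plug into this hierarchy, assuming the abstract types are exported by the package as documented; `MyRoutingBenchmark` and its field are made up:

```julia
using DecisionFocusedLearningBenchmarks

# Hypothetical benchmark. `true` marks exogenous uncertainty: the noise
# distribution is external and not influenced by the decisions taken.
struct MyRoutingBenchmark <: AbstractStochasticBenchmark{true}
    n_customers::Int
end
```
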
src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 3 additions & 1 deletion
@@ -215,13 +215,15 @@ function anticipative_solver(
         current_epoch=epoch,
     )
 
+    reward = -cost(state, decode_bitmatrix_to_routes(y_true))
+
     x = if two_dimensional_features
         compute_2D_features(state, env.instance)
     else
         compute_features(state, env.instance)
    end
 
-    return DataSample(; instance=state, y_true, x)
+    return DataSample(; instance=(; state, reward), y_true, x)
 end
 
 return obj, dataset
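
With this change, each sample's `instance` field carries a `(; state, reward)` named tuple instead of a bare state. A sketch of how downstream code might unpack it, assuming `dataset` is the dataset returned above:

```julia
# Sketch: unpack the (; state, reward) named tuple now stored in the
# `instance` field of each sample produced by the anticipative solver.
sample = first(dataset)
state = sample.instance.state    # DVSP state snapshot for this epoch
reward = sample.instance.reward  # negative cost of the y_true routes
```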

src/DynamicVehicleScheduling/plot.jl

Lines changed: 1 addition & 31 deletions
@@ -155,7 +155,7 @@ function plot_routes(
     state::DVSPState,
     routes::Vector{Vector{Int}};
     route_color=nothing,
-    route_linewidth=2, # Increased from 2 to 3
+    route_linewidth=2,
     route_alpha=0.8,
     kwargs...,
 )
@@ -191,36 +191,6 @@ function plot_routes(
     return fig
 end
 
-# """
-# $TYPEDSIGNATURES
-
-# Plot a given DVSPState with routes overlaid. This version accepts routes as a single
-# vector where routes are separated by depot visits (index 1).
-# """
-# function plot_routes(state::DVSPState, routes::Vector{Int}; kwargs...)
-#     # Convert single route vector to vector of route vectors
-#     route_vectors = Vector{Int}[]
-#     current_route = Int[]
-
-#     for location in routes
-#         if location == 1 # Depot visit indicates end of route
-#             if !isempty(current_route)
-#                 push!(route_vectors, copy(current_route))
-#                 empty!(current_route)
-#             end
-#         else
-#             push!(current_route, location)
-#         end
-#     end
-
-#     # Add the last route if it doesn't end with depot
-#     if !isempty(current_route)
-#         push!(route_vectors, current_route)
-#     end
-
-#     return plot_routes(state, route_vectors; kwargs...)
-# end
-
 """
 $TYPEDSIGNATURES

src/Utils/policy.jl

Lines changed: 18 additions & 20 deletions
@@ -31,37 +31,31 @@ $TYPEDSIGNATURES
 Run the policy on the environment and return the total reward and a dataset of observations.
 By default, the environment is reset before running the policy.
 """
-function evaluate_policy!(policy, env::AbstractEnvironment; kwargs...)
+function evaluate_policy!(
+    policy, env::AbstractEnvironment; reset_env=true, seed=get_seed(env), kwargs...
+)
+    if reset_env
+        reset!(env; reset_rng=true, seed=seed)
+    end
     total_reward = 0.0
     local labeled_dataset
     while !is_terminated(env)
         y = policy(env; kwargs...)
         features, state = observe(env)
+        reward = step!(env, y)
+        sample = DataSample(;
+            x=features, y_true=y, instance=(; state=deepcopy(state), reward)
+        )
         if @isdefined labeled_dataset
-            push!(
-                labeled_dataset,
-                DataSample(; x=features, y_true=y, instance=deepcopy(state)),
-            )
+            push!(labeled_dataset, sample)
         else
-            labeled_dataset = [DataSample(; x=features, y_true=y, instance=deepcopy(state))]
+            labeled_dataset = [sample]
         end
-        reward = step!(env, y)
         total_reward += reward
     end
     return total_reward, labeled_dataset
 end
 
-# function evaluate_policy!(policy, envs::Vector{<:AbstractEnvironment}; kwargs...)
-#     E = length(envs)
-#     rewards = zeros(Float64, E)
-#     datasets = map(1:E) do e
-#         reward, dataset = evaluate_policy!(policy, envs[e]; kwargs...)
-#         rewards[e] = reward
-#         return dataset
-#     end
-#     return rewards, vcat(datasets...)
-# end
-
 """
 $TYPEDSIGNATURES
 
@@ -73,8 +67,12 @@ function evaluate_policy!(
 )
     total_reward = 0.0
     datasets = map(1:episodes) do _i
-        reset!(env; reset_rng=(_i == 1))
-        reward, dataset = evaluate_policy!(policy, env; kwargs...)
+        if _i == 1
+            reset!(env; reset_rng=true, seed=seed)
+        else
+            reset!(env; reset_rng=false)
+        end
+        reward, dataset = evaluate_policy!(policy, env; reset_env=false, kwargs...)
         total_reward += reward
         return dataset
    end
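
A hedged usage sketch of the reworked entry point; `my_policy` and `env` are placeholders rather than names from the package, while `reset_env`, `seed`, and `get_seed` follow the diff above:

```julia
# Sketch only: `my_policy` and `env` are placeholders.
# Reproducible single run: reset once with the environment's own seed.
total_reward, dataset = evaluate_policy!(
    my_policy, env; reset_env=true, seed=get_seed(env)
)

# Each observation now bundles the state snapshot with the reward
# received for the action taken in that state.
sample = first(dataset)
@show sample.instance.state sample.instance.reward
```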

whale_shark_128786.mp4: −1.44 MB (binary file removed; not shown)
