Skip to content

Commit 223557e

Browse files
authored
Merge pull request #50 from JuliaDecisionFocusedLearning/named-tuple-info
Enforce `info` field to be a NamedTuple
2 parents 2ccbcc7 + 98a7c24 commit 223557e

32 files changed

+259
-174
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
runs-on: ubuntu-latest
2424
strategy:
2525
matrix:
26-
julia-version: ['1']
26+
julia-version: ['1.10', '1']
2727

2828
steps:
2929
- uses: actions/checkout@v5

Project.toml

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ uuid = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
33
authors = ["Members of JuliaDecisionFocusedLearning"]
44
version = "0.3.0"
55

6+
[workspace]
7+
projects = ["docs", "test"]
8+
69
[deps]
710
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
811
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -56,31 +59,12 @@ LinearAlgebra = "1"
5659
Metalhead = "0.9.4"
5760
NPZ = "0.4"
5861
Plots = "1"
59-
Printf = "1.11.0"
62+
Printf = "1"
6063
Random = "1"
6164
Requires = "1.3.0"
6265
SCIP = "0.12"
6366
SimpleWeightedGraphs = "1.4"
6467
SparseArrays = "1"
65-
Statistics = "1.11.1"
68+
Statistics = "1"
6669
StatsBase = "0.34.4"
67-
julia = "1.6"
68-
69-
[extras]
70-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
71-
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
72-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
73-
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
74-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
75-
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
76-
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
77-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
79-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
80-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
81-
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
82-
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
83-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
84-
85-
[targets]
86-
test = ["Aqua", "Documenter", "Flux", "Graphs", "JET", "JuliaFormatter", "Random", "ProgressMeter", "StableRNGs", "Statistics", "Test", "TestItemRunner", "UnicodePlots", "Zygote"]
70+
julia = "1.10"

src/Argmax2D/Argmax2D.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
6262
θ_true ./= 2 * norm(θ_true)
6363
instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
6464
y_true = maximizer(θ_true; instance)
65-
return DataSample(; x=x, θ=θ_true, y=y_true, info=instance)
65+
return DataSample(; x=x, θ=θ_true, y=y_true, instance=instance)
6666
end
6767

6868
"""
@@ -88,11 +88,11 @@ function Utils.generate_statistical_model(
8888
return model
8989
end
9090

91-
function Utils.plot_data(::Argmax2DBenchmark; info, θ, kwargs...)
91+
function Utils.plot_data(::Argmax2DBenchmark; instance, θ, kwargs...)
9292
pl = init_plot()
93-
plot_polytope!(pl, info)
93+
plot_polytope!(pl, instance)
9494
plot_objective!(pl, θ)
95-
return plot_maximizer!(pl, θ, info, maximizer)
95+
return plot_maximizer!(pl, θ, instance, maximizer)
9696
end
9797

9898
"""
@@ -101,9 +101,13 @@ $TYPEDSIGNATURES
101101
Plot the data sample for the [`Argmax2DBenchmark`](@ref).
102102
"""
103103
function Utils.plot_data(
104-
bench::Argmax2DBenchmark, sample::DataSample; info=sample.info, θ=sample.θ, kwargs...
104+
bench::Argmax2DBenchmark,
105+
sample::DataSample;
106+
instance=sample.instance,
107+
θ=sample.θ,
108+
kwargs...,
105109
)
106-
return Utils.plot_data(bench; info, θ, kwargs...)
110+
return Utils.plot_data(bench; instance, θ, kwargs...)
107111
end
108112

109113
export Argmax2DBenchmark

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ export generate_sample, generate_dataset, generate_environments, generate_enviro
7070
export generate_scenario
7171
export generate_policies
7272
export generate_statistical_model
73-
export generate_maximizer, maximizer_kwargs
73+
export generate_maximizer
7474
export generate_anticipative_solution
7575
export is_exogenous, is_endogenous
7676

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Outputs a data sample containing an [`Instance`](@ref).
8383
function Utils.generate_sample(
8484
b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0)
8585
)
86-
return DataSample(; info=Instance(b, rng))
86+
return DataSample(; instance=Instance(b, rng))
8787
end
8888

8989
"""

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
7070
dataset_size = min(dataset_size, length(files))
7171
return [
7272
DataSample(;
73-
info=Instance(
73+
instance=Instance(
7474
read_vsp_instance(files[i]);
7575
max_requests_per_epoch,
7676
Δ_dispatch,

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ function retrieve_routes_anticipative(
2121
current_task = task
2222
while current_task != 1 # < nb_tasks
2323
push!(route, current_task)
24-
local next_task
24+
next_task = -1
2525
for i in 1:nb_tasks
2626
if isapprox(y[current_task, i, t], 1; atol=0.1)
2727
next_task = i
2828
break
2929
end
3030
end
31+
@assert next_task != -1 "No next task found from task $current_task at epoch $t"
3132
current_task = next_task
3233
end
3334
push!(routes[i], route)
@@ -92,14 +93,14 @@ function anticipative_solver(
9293
job_indices = 2:nb_nodes
9394
epoch_indices = T
9495

95-
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true)
96+
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
9697

9798
@objective(
9899
model,
99100
Max,
100101
sum(
101-
-duration[i, j] * y[i, j, t] for
102-
i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices
102+
-duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes,
103+
t in epoch_indices
103104
)
104105
)
105106

@@ -171,12 +172,14 @@ function anticipative_solver(
171172
routes = epoch_routes[i]
172173
epoch_customers = epoch_indices[i]
173174

174-
y_true = VSPSolution(
175-
Vector{Int}[
176-
map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes
177-
];
178-
max_index=length(epoch_customers),
179-
).edge_matrix
175+
y_true =
176+
VSPSolution(
177+
Vector{Int}[
178+
map(idx -> findfirst(==(idx), epoch_customers), route) for
179+
route in routes
180+
];
181+
max_index=length(epoch_customers),
182+
).edge_matrix
180183

181184
location_indices = customer_index[epoch_customers]
182185
new_coordinates = env.instance.static_instance.coordinate[location_indices]
@@ -200,8 +203,7 @@ function anticipative_solver(
200203
is_must_dispatch[2:end] .= true
201204
else
202205
is_must_dispatch[2:end] .=
203-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
204-
new_start_time[2:end]
206+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
205207
end
206208
is_postponable[2:end] .= .!is_must_dispatch[2:end]
207209
# TODO: avoid code duplication with add_new_customers!
@@ -222,7 +224,7 @@ function anticipative_solver(
222224
compute_features(state, env.instance)
223225
end
224226

225-
return DataSample(; info=(; state, reward), y=y_true, x)
227+
return DataSample(; y=y_true, x, state, reward)
226228
end
227229

228230
return obj, dataset

src/DynamicVehicleScheduling/maximizer.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,14 @@ function retrieve_routes(y::AbstractArray, graph::AbstractGraph)
6262
current_task = task
6363
while current_task != 1 # < nb_tasks
6464
push!(route, current_task)
65-
local next_task
65+
next_task = -1
6666
for i in outneighbors(graph, current_task)
6767
if isapprox(y[current_task, i], 1; atol=0.1)
6868
next_task = i
6969
break
7070
end
7171
end
72+
@assert next_task != -1 "No next task found from task $current_task"
7273
current_task = next_task
7374
end
7475
push!(routes, route)
@@ -93,7 +94,7 @@ function prize_collecting_vsp(
9394
nb_nodes = nv(graph)
9495
job_indices = 2:(nb_nodes)
9596

96-
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes; has_edge(graph, i, j)] >= 0)
97+
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes; has_edge(graph, i, j)] >= 0)
9798

9899
θ_ext = fill(0.0, location_count(instance)) # no prize for must dispatch requests, only hard constraints
99100
θ_ext[instance.is_postponable] .= θ
@@ -129,7 +130,9 @@ end
129130

130131
function oracle(θ; instance::DVSPState, kwargs...)
131132
routes = prize_collecting_vsp(θ; instance=instance, kwargs...)
132-
return VSPSolution(routes; max_index=location_count(instance.state_instance)).edge_matrix
133+
return VSPSolution(
134+
routes; max_index=location_count(instance.state_instance)
135+
).edge_matrix
133136
end
134137

135138
function g(y; instance, kwargs...)

src/DynamicVehicleScheduling/static_vsp/parsing.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ Normalize all time values by the `normalization` parameter.
1010
function read_vsp_instance(filepath::String; normalization=3600.0, digits=2)
1111
type = Float64 #rounded ? Int : Float64
1212
mode = ""
13-
local edge_weight_type
14-
local edge_weight_format
13+
edge_weight_type = ""
14+
edge_weight_format = ""
1515
duration_matrix = Vector{type}[]
1616
nb_locations = 0
17-
local demand
18-
local service_time
19-
local coordinates
20-
local start_time
17+
demand = type[]
18+
service_time = type[]
19+
coordinates = Matrix{type}(undef, 0, 2)
20+
start_time = type[]
2121

2222
file = open(filepath, "r")
2323
for line in eachline(file)

src/StochasticVehicleScheduling/StochasticVehicleScheduling.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767
function Utils.objective_value(
6868
::StochasticVehicleSchedulingBenchmark, sample::DataSample, y::BitVector
6969
)
70-
return evaluate_solution(y, sample.info)
70+
return evaluate_solution(y, sample.instance)
7171
end
7272

7373
"""
@@ -98,7 +98,7 @@ function Utils.generate_sample(
9898
else
9999
nothing
100100
end
101-
return DataSample(; x, info=instance, y=y_true)
101+
return DataSample(; x, instance, y=y_true)
102102
end
103103

104104
"""
@@ -145,11 +145,12 @@ end
145145
$TYPEDSIGNATURES
146146
"""
147147
function plot_instance(
148-
::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs...
148+
::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs...
149149
)
150-
(; tasks, district_width, width) = sample.info.city
150+
@assert hasproperty(sample.instance, :city) "Sample does not contain city information."
151+
(; tasks, district_width, width) = sample.instance.city
151152
ticks = 0:district_width:width
152-
max_time = maximum(t.end_time for t in sample.info.city.tasks[1:(end - 1)])
153+
max_time = maximum(t.end_time for t in sample.instance.city.tasks[1:(end - 1)])
153154
fig = plot(;
154155
xlabel="x",
155156
ylabel="y",
@@ -204,11 +205,12 @@ end
204205
$TYPEDSIGNATURES
205206
"""
206207
function plot_solution(
207-
::StochasticVehicleSchedulingBenchmark, sample::DataSample{<:Instance{City}}; kwargs...
208+
::StochasticVehicleSchedulingBenchmark, sample::DataSample; kwargs...
208209
)
209-
(; tasks, district_width, width) = sample.info.city
210+
@assert hasproperty(sample.instance, :city) "Sample does not contain city information."
211+
(; tasks, district_width, width) = sample.instance.city
210212
ticks = 0:district_width:width
211-
solution = Solution(sample.y, sample.info)
213+
solution = Solution(sample.y, sample.instance)
212214
path_list = compute_path_list(solution)
213215
fig = plot(;
214216
xlabel="x",

0 commit comments

Comments
 (0)