Skip to content

Commit d81e625

Browse files
committed
More tests and bugfix
1 parent 90a0e90 commit d81e625

File tree

5 files changed

+20
-11
lines changed

5 files changed

+20
-11
lines changed

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export AbstractBenchmark, DataSample
4545
export generate_dataset
4646
export generate_statistical_model
4747
export generate_maximizer, maximizer_kwargs
48+
export objective_value
4849
export plot_data, plot_instance, plot_solution
4950
export compute_gap
5051

src/PortfolioOptimization/PortfolioOptimization.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ module PortfolioOptimization
22

33
using ..Utils
44
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
5-
using Distributions
5+
using Distributions: Uniform, Bernoulli
66
using Flux: Chain, Dense
7-
using Ipopt
8-
using JuMP
9-
using LinearAlgebra
10-
using Random
7+
using Ipopt: Ipopt
8+
using JuMP: @variable, @objective, @constraint, optimize!, value
9+
using LinearAlgebra: I
10+
using Random: MersenneTwister
1111

1212
"""
1313
$TYPEDEF

src/StochasticVehicleScheduling/solution/solution.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ $TYPEDSIGNATURES
4848
Create a Solution from a BitVector value.
4949
"""
5050
function Solution(value::BitVector, instance::Instance)
51-
graph = instance.graph
51+
(; graph) = instance
5252
nb_tasks = nv(graph)
5353
is_selected = falses(nb_tasks, nb_tasks)
5454
for (i, edge) in enumerate(edges(graph))
5555
if value[i]
56-
is_selected[edge.src, edge.dst] = true
56+
is_selected[src(edge), dst(edge)] = true
5757
end
5858
end
5959

@@ -112,7 +112,7 @@ function solution_from_JuMP_array(x::AbstractArray, graph::AbstractGraph)
112112
sol = falses(ne(graph)) # init
113113

114114
for (a, edge) in enumerate(edges(graph))
115-
if x[edge.src, edge.dst] == 1
115+
if x[src(edge), dst(edge)] >= 0.5
116116
sol[a] = true
117117
end
118118
end
@@ -131,8 +131,7 @@ function path_solution_from_JuMP_array(x::AbstractArray, graph::AbstractGraph)
131131
while current_task < nb_tasks
132132
sol[v_index, current_task - 1] = true
133133
next_tasks = [
134-
i for i in outneighbors(graph, current_task) if
135-
isapprox(x[current_task, i], 1; atol=0.1)
134+
i for i in outneighbors(graph, current_task) if x[current_task, i] >= 0.5
136135
]
137136
# TODO : there is a more efficient way to search for next task (but more dangerous)
138137
if length(next_tasks) == 1

src/Utils/Utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ export maximizer_kwargs
2323
export grid_graph, get_path, path_to_matrix
2424
export neg_tensor, squeeze_last_dims, average_tensor
2525
export scip_model, highs_model
26+
export objective_value
2627

2728
end

test/vsp.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
dataset = generate_dataset(b, N; seed=0)
1111
mip_dataset = generate_dataset(b, N; seed=0, algorithm=compact_mip)
1212
mipl_dataset = generate_dataset(b, N; seed=0, algorithm=compact_linearized_mip)
13-
miplc_dataset = generate_dataset(b, N; seed=0, algorithm=local_search)
13+
local_search_dataset = generate_dataset(b, N; seed=0, algorithm=local_search)
1414
@test length(dataset) == N
1515

1616
figure_1 = plot_instance(b, dataset[1])
@@ -22,6 +22,14 @@
2222
model = generate_statistical_model(b)
2323

2424
gap = compute_gap(b, dataset, model, maximizer)
25+
gap_mip = compute_gap(b, mip_dataset, model, maximizer)
26+
gap_mipl = compute_gap(b, mipl_dataset, model, maximizer)
27+
gap_local_search = compute_gap(b, local_search_dataset, model, maximizer)
28+
29+
@test gap >= 0 && gap_mip >= 0 && gap_mipl >= 0 && gap_local_search >= 0
30+
@test gap_mip gap_mipl rtol = 1e-2
31+
@test gap_mip >= gap_local_search
32+
@test gap_mip >= gap
2533

2634
for sample in dataset
2735
x = sample.x

0 commit comments

Comments
 (0)