Skip to content

Commit e17355a

Browse files
committed
Merge branch 'update-interface' into StoVSP
2 parents 0573902 + 8c98324 commit e17355a

File tree

4 files changed

+52
-9
lines changed

4 files changed

+52
-9
lines changed

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using .StochasticVehicleScheduling
4545
export AbstractBenchmark, DataSample
4646
export generate_dataset
4747
export generate_statistical_model
48-
export generate_maximizer
48+
export generate_maximizer, maximizer_kwargs
4949
export plot_data, plot_instance, plot_solution
5050
export compute_gap
5151

src/Utils/Utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export DataSample
1919
export AbstractBenchmark
2020
export generate_dataset,
2121
generate_statistical_model, generate_maximizer, plot_data, compute_gap
22+
export maximizer_kwargs
2223
export grid_graph, get_path, path_to_matrix
2324
export neg_tensor, squeeze_last_dims, average_tensor
2425
export scip_model, highs_model

src/Utils/data_sample.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ Data sample data structure.
66
# Fields
77
$TYPEDFIELDS
88
"""
9-
@kwdef struct DataSample{I,F,S,C}
9+
@kwdef struct DataSample{
10+
I,F<:AbstractArray,S<:Union{AbstractArray,Nothing},C<:Union{AbstractArray,Nothing}
11+
}
1012
"features"
1113
x::F
1214
"target cost parameters (optional)"

src/Utils/interface.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,52 @@ function compute_gap end
5959
"""
6060
$TYPEDSIGNATURES
6161
62+
For simple benchmarks where there is no instance object, maximizer does not need any keyword arguments.
63+
"""
64+
function maximizer_kwargs(
65+
::AbstractBenchmark, sample::DataSample{F,S,C,Nothing}
66+
) where {F,S,C}
67+
return NamedTuple()
68+
end
69+
70+
"""
71+
$TYPEDSIGNATURES
72+
73+
For benchmarks where there is an instance object, maximizer needs the instance object as a keyword argument.
74+
"""
75+
function maximizer_kwargs(::AbstractBenchmark, sample::DataSample)
76+
return (; instance=sample.instance)
77+
end
78+
79+
"""
80+
$TYPEDSIGNATURES
81+
6282
Default behaviour of `objective_value`.
6383
"""
64-
function objective_value(::AbstractBenchmark, θ̄::AbstractArray, y::AbstractArray)
65-
return dot(θ̄, y)
84+
function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray)
85+
return dot(θ, y)
86+
end
87+
88+
"""
89+
$TYPEDSIGNATURES
90+
91+
Compute the objective value of the target in the sample (needs to exist).
92+
"""
93+
function objective_value(
94+
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}
95+
) where {F,S<:AbstractArray,C<:AbstractArray,I}
96+
return objective_value(bench, sample.θ_true, sample.y_true)
97+
end
98+
99+
"""
100+
$TYPEDSIGNATURES
101+
102+
Compute the objective value of given solution `y`.
103+
"""
104+
function objective_value(
105+
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}, y::AbstractArray
106+
) where {F,S,C<:AbstractArray,I}
107+
return objective_value(bench, sample.θ_true, y)
66108
end
67109

68110
"""
@@ -76,13 +118,11 @@ function compute_gap(
76118
res = 0.0
77119

78120
for sample in dataset
121+
target_obj = objective_value(bench, sample)
79122
x = sample.x
80-
θ̄ = sample.θ_true
81-
= sample.y_true
82123
θ = statistical_model(x)
83-
y = maximizer(θ)
84-
target_obj = objective_value(bench, θ̄, ȳ)
85-
obj = objective_value(bench, θ̄, y)
124+
y = maximizer(θ; maximizer_kwargs(bench, sample)...)
125+
obj = objective_value(bench, sample, y)
86126
res += (target_obj - obj) / abs(target_obj)
87127
end
88128
return res / length(dataset)

0 commit comments

Comments
 (0)