Skip to content

Commit e10008e

Browse files
authored
Merge pull request #19 from JuliaDecisionFocusedLearning/update-interface
Improve interface
2 parents f1efe57 + 8c98324 commit e10008e

File tree

5 files changed

+53
-10
lines changed

5 files changed

+53
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2626
DataDeps = "0.7"
2727
Distributions = "0.25"
2828
DocStringExtensions = "0.9"
29-
Flux = "0.14"
29+
Flux = "0.16"
3030
Graphs = "1.11"
3131
HiGHS = "1.9"
3232
Images = "0.26"

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using .PortfolioOptimization
3838
export AbstractBenchmark, DataSample
3939
export generate_dataset
4040
export generate_statistical_model
41-
export generate_maximizer
41+
export generate_maximizer, maximizer_kwargs
4242
export plot_data
4343
export compute_gap
4444

src/Utils/Utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ export DataSample
1515
export AbstractBenchmark
1616
export generate_dataset,
1717
generate_statistical_model, generate_maximizer, plot_data, compute_gap
18+
export maximizer_kwargs
1819
export grid_graph, get_path, path_to_matrix
1920
export neg_tensor, squeeze_last_dims, average_tensor
2021

src/Utils/data_sample.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ Data sample data structure.
66
# Fields
77
$TYPEDFIELDS
88
"""
9-
@kwdef struct DataSample{F,S,C,I}
9+
@kwdef struct DataSample{
10+
F<:AbstractArray,S<:Union{AbstractArray,Nothing},C<:Union{AbstractArray,Nothing},I
11+
}
1012
"features"
1113
x::F
1214
"target cost parameters (optional)"

src/Utils/interface.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,52 @@ function compute_gap end
5555
"""
5656
$TYPEDSIGNATURES
5757
58+
For simple benchmarks where there is no instance object, maximizer does not need any keyword arguments.
59+
"""
60+
function maximizer_kwargs(
61+
::AbstractBenchmark, sample::DataSample{F,S,C,Nothing}
62+
) where {F,S,C}
63+
return NamedTuple()
64+
end
65+
66+
"""
67+
$TYPEDSIGNATURES
68+
69+
For benchmarks where there is an instance object, maximizer needs the instance object as a keyword argument.
70+
"""
71+
function maximizer_kwargs(::AbstractBenchmark, sample::DataSample)
72+
return (; instance=sample.instance)
73+
end
74+
75+
"""
76+
$TYPEDSIGNATURES
77+
5878
Default behaviour of `objective_value`.
5979
"""
60-
function objective_value(::AbstractBenchmark, θ̄::AbstractArray, y::AbstractArray)
61-
return dot(θ̄, y)
80+
function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray)
81+
return dot(θ, y)
82+
end
83+
84+
"""
85+
$TYPEDSIGNATURES
86+
87+
Compute the objective value of the target in the sample (needs to exist).
88+
"""
89+
function objective_value(
90+
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}
91+
) where {F,S<:AbstractArray,C<:AbstractArray,I}
92+
return objective_value(bench, sample.θ_true, sample.y_true)
93+
end
94+
95+
"""
96+
$TYPEDSIGNATURES
97+
98+
Compute the objective value of given solution `y`.
99+
"""
100+
function objective_value(
101+
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}, y::AbstractArray
102+
) where {F,S,C<:AbstractArray,I}
103+
return objective_value(bench, sample.θ_true, y)
62104
end
63105

64106
"""
@@ -72,13 +114,11 @@ function compute_gap(
72114
res = 0.0
73115

74116
for sample in dataset
117+
target_obj = objective_value(bench, sample)
75118
x = sample.x
76-
θ̄ = sample.θ_true
77-
= sample.y_true
78119
θ = statistical_model(x)
79-
y = maximizer(θ)
80-
target_obj = objective_value(bench, θ̄, ȳ)
81-
obj = objective_value(bench, θ̄, y)
120+
y = maximizer(θ; maximizer_kwargs(bench, sample)...)
121+
obj = objective_value(bench, sample, y)
82122
res += (target_obj - obj) / abs(target_obj)
83123
end
84124
return res / length(dataset)

0 commit comments

Comments
 (0)