Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
DataDeps = "0.7"
Distributions = "0.25"
DocStringExtensions = "0.9"
Flux = "0.14"
Flux = "0.16"
Graphs = "1.11"
HiGHS = "1.9"
Images = "0.26"
Expand Down
2 changes: 1 addition & 1 deletion src/DecisionFocusedLearningBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using .PortfolioOptimization
export AbstractBenchmark, DataSample
export generate_dataset
export generate_statistical_model
export generate_maximizer
export generate_maximizer, maximizer_kwargs
export plot_data
export compute_gap

Expand Down
1 change: 1 addition & 0 deletions src/Utils/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export DataSample
export AbstractBenchmark
export generate_dataset,
generate_statistical_model, generate_maximizer, plot_data, compute_gap
export maximizer_kwargs
export grid_graph, get_path, path_to_matrix
export neg_tensor, squeeze_last_dims, average_tensor

Expand Down
4 changes: 3 additions & 1 deletion src/Utils/data_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ Data sample data structure.
# Fields
$TYPEDFIELDS
"""
@kwdef struct DataSample{F,S,C,I}
@kwdef struct DataSample{
F<:AbstractArray,S<:Union{AbstractArray,Nothing},C<:Union{AbstractArray,Nothing},I
}
"features"
x::F
"target cost parameters (optional)"
Expand Down
54 changes: 47 additions & 7 deletions src/Utils/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,52 @@
"""
$TYPEDSIGNATURES

For simple benchmarks where there is no instance object, maximizer does not need any keyword arguments.
"""
function maximizer_kwargs(

Check warning on line 60 in src/Utils/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/Utils/interface.jl#L60

Added line #L60 was not covered by tests
::AbstractBenchmark, sample::DataSample{F,S,C,Nothing}
) where {F,S,C}
return NamedTuple()

Check warning on line 63 in src/Utils/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/Utils/interface.jl#L63

Added line #L63 was not covered by tests
end

"""
$TYPEDSIGNATURES

For benchmarks where there is an instance object, maximizer needs the instance object as a keyword argument.
"""
function maximizer_kwargs(::AbstractBenchmark, sample::DataSample)
return (; instance=sample.instance)
end

"""
$TYPEDSIGNATURES

Default behaviour of `objective_value`.
"""
function objective_value(::AbstractBenchmark, θ̄::AbstractArray, y::AbstractArray)
return dot(θ̄, y)
function objective_value(::AbstractBenchmark, θ::AbstractArray, y::AbstractArray)
return dot(θ, y)
end

"""
$TYPEDSIGNATURES

Compute the objective value of the target in the sample (needs to exist).
"""
function objective_value(
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}
) where {F,S<:AbstractArray,C<:AbstractArray,I}
return objective_value(bench, sample.θ_true, sample.y_true)
end

"""
$TYPEDSIGNATURES

Compute the objective value of given solution `y`.
"""
function objective_value(
bench::AbstractBenchmark, sample::DataSample{F,S,C,I}, y::AbstractArray
) where {F,S,C<:AbstractArray,I}
return objective_value(bench, sample.θ_true, y)
end

"""
Expand All @@ -72,13 +114,11 @@
res = 0.0

for sample in dataset
target_obj = objective_value(bench, sample)
x = sample.x
θ̄ = sample.θ_true
ȳ = sample.y_true
θ = statistical_model(x)
y = maximizer(θ)
target_obj = objective_value(bench, θ̄, ȳ)
obj = objective_value(bench, θ̄, y)
y = maximizer(θ; maximizer_kwargs(bench, sample)...)
obj = objective_value(bench, sample, y)
res += (target_obj - obj) / abs(target_obj)
end
return res / length(dataset)
Expand Down
Loading