
Commit 9721b22

portfolio optimization tests
1 parent ab34a16 commit 9721b22

File tree: 2 files changed, +22 -34 lines

test/fixed_size_shortest_path.jl

Lines changed: 0 additions & 32 deletions

@@ -2,43 +2,11 @@ module FixedSizeShortestPathTest
 
 using DecisionFocusedLearningBenchmarks.FixedSizeShortestPath
 
-# using Flux
-# using InferOpt
-# using ProgressMeter
-# using UnicodePlots
-# using Zygote
-
 bench = FixedSizeShortestPathBenchmark()
 
 (; features, costs, solutions) = generate_dataset(bench)
 
 model = generate_statistical_model(bench)
 maximizer = generate_maximizer(bench)
 
-# perturbed = PerturbedAdditive(maximizer; nb_samples=10, ε=0.1)
-# fyl = FenchelYoungLoss(perturbed)
-
-# opt_state = Flux.setup(Adam(), model)
-# loss_history = Float64[]
-# gap_history = Float64[]
-# E = 100
-# @showprogress for epoch in 1:E
-#     loss = 0.0
-#     for (x, y) in zip(features, solutions)
-#         val, grads = Flux.withgradient(model) do m
-#             θ = m(x)
-#             fyl(θ, y)
-#         end
-#         loss += val
-#         Flux.update!(opt_state, model, grads[1])
-#     end
-#     push!(loss_history, loss ./ E)
-#     push!(
-#         gap_history, compute_gap(bench, model, features, costs, solutions, maximizer) .* 100
-#     )
-# end
-
-# println(lineplot(loss_history; title="Loss"))
-# println(lineplot(gap_history; title="Gap"))
-
 end
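
The commented-out block removed above outlined a decision-focused training loop for this benchmark, built on InferOpt's PerturbedAdditive regularization and Fenchel-Young loss. For reference only, a runnable sketch of that same workflow could look like the code below; it assumes Flux, InferOpt, ProgressMeter, and UnicodePlots are installed, and it reuses bench, model, maximizer, features, costs, and solutions from the remaining test body.

# Sketch of the training loop from the removed comments (not part of the committed test).
# Assumes Flux, InferOpt, ProgressMeter, and UnicodePlots are available.
using Flux, InferOpt, ProgressMeter, UnicodePlots

perturbed = PerturbedAdditive(maximizer; nb_samples=10, ε=0.1)
fyl = FenchelYoungLoss(perturbed)

opt_state = Flux.setup(Adam(), model)
loss_history = Float64[]
gap_history = Float64[]
E = 100
@showprogress for epoch in 1:E
    loss = 0.0
    for (x, y) in zip(features, solutions)
        # Differentiate the Fenchel-Young loss through the statistical model.
        val, grads = Flux.withgradient(model) do m
            θ = m(x)
            fyl(θ, y)
        end
        loss += val
        Flux.update!(opt_state, model, grads[1])
    end
    # Total epoch loss scaled by 1/E, mirroring the removed comments.
    push!(loss_history, loss / E)
    # Optimality gap in percent, as computed by the benchmark package.
    push!(gap_history, 100 * compute_gap(bench, model, features, costs, solutions, maximizer))
end

println(lineplot(loss_history; title="Loss"))
println(lineplot(gap_history; title="Gap"))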

test/portfolio_optimization.jl

Lines changed: 22 additions & 2 deletions

@@ -1,9 +1,29 @@
 @testitem "Portfolio Optimization" begin
     using DecisionFocusedLearningBenchmarks
 
-    b = PortfolioOptimizationBenchmark()
+    d = 50
+    p = 5
+    b = PortfolioOptimizationBenchmark(; d=d, p=p)
 
-    dataset = generate_dataset(b, 100)
+    dataset = generate_dataset(b, 50)
     model = generate_statistical_model(b)
     maximizer = generate_maximizer(b)
+
+    for sample in dataset
+        x = sample.x
+        θ_true = sample.θ
+        y_true = sample.y
+        @test size(x) == (p,)
+        @test length(θ_true) == d
+        @test length(y_true) == d
+        @test isnothing(sample.instance)
+        @test all(y_true .== maximizer(θ_true))
+
+        θ = model(x)
+        @test length(θ) == d
+
+        y = maximizer(θ)
+        @test length(y) == d
+        @test sum(y) <= 1
+    end
 end
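
The new test is written as a @testitem block, which points to TestItems.jl-style test organization. Assuming the package wires its suite through TestItemRunner.jl (an assumption about the repository's test harness, not something this commit shows), the item can be run on its own by filtering on its name:

# Assumes TestItemRunner.jl drives the package tests; run from the package environment.
using TestItemRunner
@run_package_tests filter = ti -> ti.name == "Portfolio Optimization"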
