
Commit d51fe95

Fix tests

1 parent 0a5f643 commit d51fe95

5 files changed: +59 −62 lines changed

src/Argmax2D/Argmax2D.jl

Lines changed: 9 additions & 13 deletions
@@ -7,7 +7,7 @@ using Flux: Chain, Dense
 using LaTeXStrings: @L_str
 using LinearAlgebra: dot, norm
 using Plots: Plots
-using Random: Random, MersenneTwister
+using Random: Random, MersenneTwister, AbstractRNG

 include("polytope.jl")
@@ -53,20 +53,16 @@ maximizer(θ; instance, kwargs...) = instance[argmax(dot(θ, v) for v in instanc
 """
 $TYPEDSIGNATURES

-Generate a dataset for the [`Argmax2DBenchmark`](@ref).
+Generate a sample for the [`Argmax2DBenchmark`](@ref).
 """
-function Utils.generate_dataset(
-    bench::Argmax2DBenchmark, dataset_size=10; seed=nothing, rng=MersenneTwister(seed)
-)
+function Utils.generate_sample(bench::Argmax2DBenchmark, rng::AbstractRNG)
     (; nb_features, encoder, polytope_vertex_range) = bench
-    return map(1:dataset_size) do _
-        x = randn(rng, Float32, nb_features)
-        θ_true = encoder(x)
-        θ_true ./= 2 * norm(θ_true)
-        instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
-        y_true = maximizer(θ_true; instance)
-        return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
-    end
+    x = randn(rng, Float32, nb_features)
+    θ_true = encoder(x)
+    θ_true ./= 2 * norm(θ_true)
+    instance = build_polytope(rand(rng, polytope_vertex_range); shift=rand(rng))
+    y_true = maximizer(θ_true; instance)
+    return DataSample(; x=x, θ_true=θ_true, y_true=y_true, instance=instance)
 end

 """

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 2 additions & 2 deletions
@@ -48,6 +48,7 @@ end

 include("instance.jl")
 include("environment.jl")
+include("policies.jl")

 customer_choice_model(b::DynamicAssortmentBenchmark) = b.customer_choice_model
 item_count(b::DynamicAssortmentBenchmark) = b.N
@@ -56,9 +57,8 @@ assortment_size(b::DynamicAssortmentBenchmark) = b.K
 max_steps(b::DynamicAssortmentBenchmark) = b.max_steps

 function Utils.generate_sample(
-    b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0); seed=nothing
+    b::DynamicAssortmentBenchmark, rng::AbstractRNG=MersenneTwister(0)
 )
-    Random.seed!(rng, seed)
     return DataSample(; instance=Instance(b, rng))
 end
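
Dropping `Random.seed!(rng, seed)` means reproducibility now comes from the RNG handed in, not from a `seed` keyword. A small sketch under that reading; the default `DynamicAssortmentBenchmark()` constructor is assumed for illustration:

    using Random: MersenneTwister

    b = DynamicAssortmentBenchmark()                   # hypothetical default constructor
    s1 = Utils.generate_sample(b, MersenneTwister(1))  # seed the RNG at the call site
    s2 = Utils.generate_sample(b, MersenneTwister(1))  # same seed, same generated instance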

src/DynamicAssortment/environment.jl

Lines changed: 0 additions & 42 deletions
@@ -166,45 +166,3 @@ function compute_expected_revenue(env::Environment, S)
     expected_revenue = dot(probs, r)
     return expected_revenue
 end
-
-function expert_solution(env::Environment)
-    N = item_count(env)
-    K = assortment_size(env)
-    best_S = falses(N)
-    best_revenue = -1.0
-    S_vec = falses(N)
-    for S in combinations(1:N, K)
-        S_vec .= false
-        S_vec[S] .= true
-        expected_revenue = compute_expected_revenue(env, S_vec)
-        if expected_revenue > best_revenue
-            best_S, best_revenue = copy(S_vec), expected_revenue
-        end
-    end
-    return best_S
-end
-
-function greedy_solution(env::Environment)
-    maximizer = generate_maximizer(env.instance.config)
-    return maximizer(prices(env))
-end
-
-function run_policy(env::Environment, episodes::Int; first_seed=1, policy=expert_solution)
-    dataset = []
-    rev_global = Float64[]
-    for i in 1:episodes
-        rev_episode = 0.0
-        CommonRLInterface.reset!(env; seed=first_seed - 1 + i, reset_seed=true)
-        training_instances = []
-        while !CommonRLInterface.terminated(env)
-            S = policy(env)
-            features = CommonRLInterface.observe(env)
-            push!(training_instances, DataSample(; x=features, y_true=S))
-            reward = CommonRLInterface.act!(env, S)
-            rev_episode += reward
-        end
-        push!(rev_global, rev_episode)
-        push!(dataset, training_instances)
-    end
-    return mean(rev_global), rev_global, dataset
-end

src/DynamicAssortment/policies.jl

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+function expert_policy(env::Environment)
+    N = item_count(env)
+    K = assortment_size(env)
+    best_S = falses(N)
+    best_revenue = -1.0
+    S_vec = falses(N)
+    for S in combinations(1:N, K)
+        S_vec .= false
+        S_vec[S] .= true
+        expected_revenue = compute_expected_revenue(env, S_vec)
+        if expected_revenue > best_revenue
+            best_S, best_revenue = copy(S_vec), expected_revenue
+        end
+    end
+    return best_S
+end
+
+function greedy_policy(env::Environment)
+    maximizer = generate_maximizer(env.instance.config)
+    return maximizer(prices(env))
+end
+
+function run_policy(env::Environment, episodes::Int; first_seed=1, policy=expert_policy)
+    dataset = []
+    rev_global = Float64[]
+    for i in 1:episodes
+        rev_episode = 0.0
+        CommonRLInterface.reset!(env; seed=first_seed - 1 + i, reset_seed=true)
+        training_instances = []
+        while !CommonRLInterface.terminated(env)
+            S = policy(env)
+            features = CommonRLInterface.observe(env)
+            push!(training_instances, DataSample(; x=features, y_true=S))
+            reward = CommonRLInterface.act!(env, S)
+            rev_episode += reward
+        end
+        push!(rev_global, rev_episode)
+        push!(dataset, training_instances)
+    end
+    return mean(rev_global), rev_global, dataset
+end
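
Only the location and names change here: `expert_solution`/`greedy_solution` become `expert_policy`/`greedy_policy`, with `run_policy` defaulting to the exhaustive expert enumeration over `combinations(1:N, K)`. A hedged usage sketch; the `Environment(instance)` constructor below is an assumption, not something this diff confirms:

    using Random: MersenneTwister

    b = DynamicAssortmentBenchmark()
    sample = Utils.generate_sample(b, MersenneTwister(0))
    env = Environment(sample.instance)   # hypothetical constructor

    # Roll out 5 episodes with each policy; run_policy returns the mean episode
    # revenue, the per-episode revenues, and the collected imitation dataset.
    mean_expert, revs_expert, data_expert = run_policy(env, 5)  # defaults to expert_policy
    mean_greedy, revs_greedy, data_greedy = run_policy(env, 5; policy=greedy_policy)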

src/DynamicVehicleScheduling/algorithms/anticipative_solver.jl

Lines changed: 7 additions & 5 deletions
@@ -201,9 +201,11 @@ end
 end

 function (solver::AnticipativeSolver)(env::DVSPEnv, scenario=env.scenario; reset_env=false)
-    if solver.is_2D
-        return anticipative_solver(env, scenario; model_builder=highs_model_2d, reset_env)
-    else
-        return anticipative_solver(env, scenario; model_builder=highs_model, reset_env)
-    end
+    return anticipative_solver(
+        env,
+        scenario;
+        model_builder=highs_model,
+        reset_env,
+        two_dimensional_features=solver.is_2D,
+    )
 end
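
Both branches collapse into one call: the model builder stays `highs_model`, and whether 2D features are used is forwarded through the `two_dimensional_features` keyword driven by `solver.is_2D`. A call-site sketch; the keyword constructor for `AnticipativeSolver` and the pre-built `env::DVSPEnv` are assumptions for illustration:

    solver_1d = AnticipativeSolver(; is_2D=false)   # hypothetical keyword constructor
    solver_2d = AnticipativeSolver(; is_2D=true)

    routes_1d = solver_1d(env)   # forwards two_dimensional_features=false
    routes_2d = solver_2d(env)   # forwards two_dimensional_features=true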
