Skip to content

Commit e8be496

Browse files
committed
Dynamic assortment is in a good state; fix docs; working on DVSP
1 parent a90a3d5 commit e8be496

File tree

13 files changed

+255
-170
lines changed

13 files changed

+255
-170
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ authors = ["Members of JuliaDecisionFocusedLearning"]
44
version = "0.2.4"
55

66
[deps]
7-
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
8-
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
97
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
8+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
109
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
1110
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
1211
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -35,9 +34,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3534
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3635

3736
[compat]
38-
Combinatorics = "1.0.3"
39-
CommonRLInterface = "0.3.3"
4037
Colors = "0.13.1"
38+
Combinatorics = "1.0.3"
4139
ConstrainedShortestPaths = "0.6.0"
4240
DataDeps = "0.7"
4341
Distributions = "0.25"

docs/src/api/dynamic_assorment.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Dynamic Assortment
2+
3+
## Public
4+
5+
```@autodocs
6+
Modules = [DecisionFocusedLearningBenchmarks.DynamicAssortment]
7+
Private = false
8+
```
9+
10+
## Private
11+
12+
```@autodocs
13+
Modules = [DecisionFocusedLearningBenchmarks.DynamicAssortment]
14+
Public = false
15+
```
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Dynamic Assortment
2+
3+
[`DynamicAssortmentBenchmark`](@ref).

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,16 @@ include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
5959
include("DynamicAssortment/DynamicAssortment.jl")
6060

6161
using .Utils
62-
using .Argmax
63-
using .Argmax2D
64-
using .Ranking
65-
using .SubsetSelection
66-
using .Warcraft
67-
using .FixedSizeShortestPath
68-
using .PortfolioOptimization
69-
using .StochasticVehicleScheduling
70-
using .DynamicVehicleScheduling
71-
using .DynamicAssortment
7262

7363
# Interface
7464
export AbstractBenchmark, AbstractStochasticBenchmark, AbstractDynamicBenchmark, DataSample
65+
export AbstractEnv, get_seed, is_terminated, observe, reset!, step!
66+
67+
export Policy, run_policy!
7568

7669
export generate_sample, generate_dataset, generate_environments, generate_environment
7770
export generate_scenario
78-
export generate_scenario_generator, generate_anticipative_solver
71+
export generate_policies
7972
export generate_statistical_model
8073
export generate_maximizer, maximizer_kwargs
8174
export generate_anticipative_solution
@@ -86,15 +79,26 @@ export plot_data, plot_instance, plot_solution
8679
export compute_gap
8780

8881
# Export all benchmarks
89-
export ArgmaxBenchmark
82+
using .Argmax
83+
using .Argmax2D
84+
using .Ranking
85+
using .SubsetSelection
86+
using .Warcraft
87+
using .FixedSizeShortestPath
88+
using .PortfolioOptimization
89+
using .StochasticVehicleScheduling
90+
using .DynamicVehicleScheduling
91+
using .DynamicAssortment
92+
9093
export Argmax2DBenchmark
91-
export RankingBenchmark
92-
export SubsetSelectionBenchmark
93-
export WarcraftBenchmark
94+
export ArgmaxBenchmark
95+
export DynamicAssortmentBenchmark
96+
export DynamicVehicleSchedulingBenchmark
9497
export FixedSizeShortestPathBenchmark
9598
export PortfolioOptimizationBenchmark
99+
export RankingBenchmark
96100
export StochasticVehicleSchedulingBenchmark
97-
export DynamicVehicleSchedulingBenchmark
98-
export DynamicAssortmentBenchmark
101+
export SubsetSelectionBenchmark
102+
export WarcraftBenchmark
99103

100104
end # module DecisionFocusedLearningBenchmarks

src/DynamicAssortment/DynamicAssortment.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module DynamicAssortment
22

33
using ..Utils
44

5-
using CommonRLInterface: CommonRLInterface, AbstractEnv
65
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
76
using Distributions: Uniform, Categorical
87
using LinearAlgebra: dot
@@ -62,17 +61,35 @@ function Utils.generate_sample(
6261
return DataSample(; instance=Instance(b, rng))
6362
end
6463

64+
function Utils.generate_statistical_model(b::DynamicAssortmentBenchmark; seed=nothing)
65+
Random.seed!(seed)
66+
d = feature_count(b)
67+
return Chain(Dense(d + 8 => 5), Dense(5 => 1), vec)
68+
end
69+
6570
function Utils.generate_maximizer(b::DynamicAssortmentBenchmark)
6671
return TopKMaximizer(assortment_size(b))
6772
end
6873

6974
function Utils.generate_environment(
70-
::DynamicAssortmentBenchmark,
71-
instance::Instance;
72-
seed=nothing,
73-
rng::AbstractRNG=MersenneTwister(seed),
75+
::DynamicAssortmentBenchmark, instance::Instance, rng::AbstractRNG
7476
)
75-
return Environment(instance; seed=seed, rng=rng)
77+
seed = rand(rng, 1:typemax(Int))
78+
return Environment(instance; seed)
79+
end
80+
81+
function Utils.generate_policies(b::DynamicAssortmentBenchmark)
82+
greedy = Policy(
83+
"Greedy",
84+
"policy that selects the assortment with items with the highest prices",
85+
greedy_policy,
86+
)
87+
expert = Policy(
88+
"Expert",
89+
"policy that selects the assortment with the highest expected revenue",
90+
expert_policy,
91+
)
92+
return (expert, greedy)
7693
end
7794

7895
export DynamicAssortmentBenchmark

src/DynamicAssortment/environment.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Environment for the dynamic assortment problem.
77
$TYPEDFIELDS
88
"""
99
@kwdef mutable struct Environment{I<:Instance,R<:AbstractRNG,S<:Union{Nothing,Int}} <:
10-
AbstractEnv
10+
Utils.AbstractEnvironment
1111
"associated instance"
1212
instance::I
1313
"current step"
@@ -43,23 +43,22 @@ function Environment(instance::Instance; seed=0, rng::AbstractRNG=MersenneTwiste
4343
features=full_features,
4444
d_features=zeros(2, N),
4545
)
46-
CommonRLInterface.reset!(env; reset_seed=true)
46+
Utils.reset!(env; reset_seed=true)
4747
return env
4848
end
4949

50+
Utils.get_seed(env::Environment) = env.seed
5051
customer_choice_model(b::Environment) = customer_choice_model(b.instance)
5152
item_count(b::Environment) = item_count(b.instance)
5253
feature_count(b::Environment) = feature_count(b.instance)
5354
assortment_size(b::Environment) = assortment_size(b.instance)
5455
max_steps(b::Environment) = max_steps(b.instance)
5556
prices(b::Environment) = b.instance.prices
56-
# features(b::Environment) = b.instance.features
57-
# starting_hype_and_saturation(b::Environment) = b.instance.starting_hype_and_saturation
5857

5958
## Basic operations of environment
6059

6160
# Reset the environment
62-
function CommonRLInterface.reset!(env::Environment; reset_seed=false, seed=env.seed)
61+
function Utils.reset!(env::Environment; reset_seed=false, seed=env.seed)
6362
reset_seed && Random.seed!(env.rng, seed)
6463

6564
env.step = 1
@@ -79,18 +78,19 @@ function CommonRLInterface.reset!(env::Environment; reset_seed=false, seed=env.s
7978
return nothing
8079
end
8180

82-
function CommonRLInterface.terminated(env::Environment)
81+
function Utils.is_terminated(env::Environment)
8382
return env.step > max_steps(env)
8483
end
8584

86-
function CommonRLInterface.observe(env::Environment)
85+
function Utils.observe(env::Environment)
8786
delta_features = env.features[2:3, :] .- env.instance.starting_hype_and_saturation
8887
return vcat(
8988
env.features,
9089
env.d_features,
9190
delta_features,
9291
ones(1, item_count(env)) .* (env.step / max_steps(env) * 10),
93-
) #./ 10
92+
) ./ 10,
93+
nothing
9494
end
9595

9696
# Compute the hype vector
@@ -149,9 +149,10 @@ function choice_probabilities(env::Environment, S)
149149
end
150150

151151
# Purchase decision
152-
function CommonRLInterface.act!(env::Environment, S)
152+
function Utils.step!(env::Environment, assortment)
153+
@assert !Utils.is_terminated(env) "Environment is terminated, cannot act!"
153154
r = prices(env)
154-
probs = choice_probabilities(env, S)
155+
probs = choice_probabilities(env, assortment)
155156
item = rand(env.rng, Categorical(probs))
156157
reward = r[item]
157158
buy_item!(env, item)

src/DynamicVehicleScheduling/DynamicVehicleScheduling.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module DynamicVehicleScheduling
33
using ..Utils
44

55
using Base: @kwdef
6-
using CommonRLInterface: CommonRLInterface, AbstractEnv, reset!, terminated, observe, act!
76
using DataDeps: @datadep_str
87
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
98
using Graphs
@@ -39,11 +38,11 @@ include("algorithms/anticipative_solver.jl")
3938
include("learning/features.jl")
4039
include("learning/2d_features.jl")
4140

42-
include("policy/abstract_vsp_policy.jl")
43-
include("policy/greedy_policy.jl")
44-
include("policy/lazy_policy.jl")
45-
include("policy/anticipative_policy.jl")
46-
include("policy/kleopatra_policy.jl")
41+
# include("policy/abstract_vsp_policy.jl")
42+
# include("policy/greedy_policy.jl")
43+
# include("policy/lazy_policy.jl")
44+
# include("policy/anticipative_policy.jl")
45+
# include("policy/kleopatra_policy.jl")
4746

4847
include("maximizer.jl")
4948

@@ -56,13 +55,13 @@ Abstract type for dynamic vehicle scheduling benchmarks.
5655
$TYPEDFIELDS
5756
"""
5857
@kwdef struct DynamicVehicleSchedulingBenchmark <: AbstractDynamicBenchmark{true}
59-
"todo"
58+
"maximum number of customers entering the system per epoch"
6059
max_requests_per_epoch::Int = 10
61-
"todo"
60+
"time between decision and dispatch of a vehicle"
6261
Δ_dispatch::Float64 = 1.0
63-
"todo"
62+
"duration of an epoch"
6463
epoch_duration::Float64 = 1.0
65-
"todo"
64+
"whether to use two-dimensional features"
6665
two_dimensional_features::Bool = false
6766
end
6867

@@ -83,9 +82,10 @@ function Utils.generate_dataset(b::DynamicVehicleSchedulingBenchmark, dataset_si
8382
end
8483

8584
function Utils.generate_environment(
86-
::DynamicVehicleSchedulingBenchmark, instance::Instance; kwargs...
85+
::DynamicVehicleSchedulingBenchmark, instance::Instance, rng::AbstractRNG
8786
)
88-
return DVSPEnv(instance; kwargs...)
87+
seed = rand(rng, 1:typemax(Int))
88+
return DVSPEnv(instance; seed)
8989
end
9090

9191
function Utils.generate_maximizer(::DynamicVehicleSchedulingBenchmark)
@@ -105,7 +105,5 @@ function Utils.generate_anticipative_solution(
105105
end
106106

107107
export DynamicVehicleSchedulingBenchmark
108-
export run_policy!,
109-
GreedyVSPPolicy, LazyVSPPolicy, KleopatraVSPPolicy, AnticipativeVSPPolicy
110108

111109
end

src/DynamicVehicleScheduling/algorithms/prize_collecting_vsp.jl

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -126,90 +126,3 @@ function prize_collecting_vsp(
126126

127127
return retrieve_routes(value.(y), graph)
128128
end
129-
130-
# # ?
131-
# function prize_collecting_vsp_Q(
132-
# θ::AbstractVector,
133-
# vals::AbstractVector;
134-
# instance::DVSPState,
135-
# model_builder=highs_model,
136-
# kwargs...,
137-
# )
138-
# (; duration) = instance.instance
139-
# graph = create_graph(instance)
140-
# model = model_builder()
141-
# set_silent(model)
142-
# nb_nodes = nv(graph)
143-
# job_indices = 2:(nb_nodes)
144-
# @variable(model, y[i=1:nb_nodes, j=1:nb_nodes; has_edge(graph, i, j)] >= 0)
145-
# θ_ext = fill(0.0, location_count(instance.instance)) # no prize for must dispatch requests, only hard constraints
146-
# θ_ext[instance.is_postponable] .= θ
147-
# # v_ext = fill(0.0, nb_locations(instance.instance)) # no prize for must dispatch requests, only hard constraints
148-
# # v_ext[instance.is_postponable] .= vals
149-
# @objective(
150-
# model,
151-
# Max,
152-
# sum(
153-
# (θ_ext[dst(edge)] + vals[dst(edge)] - duration[src(edge), dst(edge)]) *
154-
# y[src(edge), dst(edge)] for edge in edges(graph)
155-
# )
156-
# )
157-
# @constraint(
158-
# model,
159-
# flow[i in 2:nb_nodes],
160-
# sum(y[j, i] for j in inneighbors(graph, i)) ==
161-
# sum(y[i, j] for j in outneighbors(graph, i))
162-
# )
163-
# @constraint(
164-
# model, demand[i in job_indices], sum(y[j, i] for j in inneighbors(graph, i)) <= 1
165-
# )
166-
# # must dispatch constraints
167-
# @constraint(
168-
# model,
169-
# demand_must_dispatch[i in job_indices; instance.is_must_dispatch[i]],
170-
# sum(y[j, i] for j in inneighbors(graph, i)) == 1
171-
# )
172-
# optimize!(model)
173-
# return retrieve_routes(value.(y), graph)
174-
# end
175-
176-
# function my_objective_value(θ, routes; instance)
177-
# (; duration) = instance.instance
178-
# total = 0.0
179-
# θ_ext = fill(0.0, location_count(instance))
180-
# θ_ext[instance.is_postponable] .= θ
181-
# for route in routes
182-
# for (u, v) in partition(vcat(1, route), 2, 1)
183-
# total += θ_ext[v] - duration[u, v]
184-
# end
185-
# end
186-
# return -total
187-
# end
188-
189-
# function _objective_value(θ, routes; instance)
190-
# (; duration) = instance.instance
191-
# total = 0.0
192-
# θ_ext = fill(0.0, location_count(instance))
193-
# θ_ext[instance.is_postponable] .= θ
194-
# mapping = cumsum(instance.is_postponable)
195-
# g = falses(length(θ))
196-
# for route in routes
197-
# for (u, v) in partition(vcat(1, route), 2, 1)
198-
# total -= duration[u, v]
199-
# if instance.is_postponable[v]
200-
# total += θ_ext[v]
201-
# g[mapping[v]] = 1
202-
# end
203-
# end
204-
# end
205-
# return -total, g
206-
# end
207-
208-
# function ChainRulesCore.rrule(::typeof(my_objective_value), θ, routes; instance)
209-
# total, g = _objective_value(θ, routes; instance)
210-
# function pullback(dy)
211-
# g = g .* dy
212-
# return NoTangent(), g, NoTangent()
213-
# end
214-
# return total, pullback
215-
# end

0 commit comments

Comments (0)