Skip to content

Commit 0c7e20a

Browse files
committed
first version of DynamicAssortmentBenchmark
1 parent b30fe33 commit 0c7e20a

File tree

4 files changed

+312
-0
lines changed

4 files changed

+312
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Members of JuliaDecisionFocusedLearning"]
44
version = "0.2.2"
55

66
[deps]
7+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
78
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
89
ConstrainedShortestPaths = "b3798467-87dc-4d99-943d-35a1bd39e395"
910
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
@@ -32,6 +33,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3233
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3334

3435
[compat]
36+
Combinatorics = "1.0.3"
3537
CommonRLInterface = "0.3.3"
3638
ConstrainedShortestPaths = "0.6.0"
3739
DataDeps = "0.7"

src/DecisionFocusedLearningBenchmarks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ include("FixedSizeShortestPath/FixedSizeShortestPath.jl")
5555
include("PortfolioOptimization/PortfolioOptimization.jl")
5656
include("StochasticVehicleScheduling/StochasticVehicleScheduling.jl")
5757
include("DynamicVehicleScheduling/DynamicVehicleScheduling.jl")
58+
include("DynamicAssortment/DynamicAssortment.jl")
5859

5960
using .Utils
6061
using .Argmax
@@ -65,6 +66,7 @@ using .FixedSizeShortestPath
6566
using .PortfolioOptimization
6667
using .StochasticVehicleScheduling
6768
using .DynamicVehicleScheduling
69+
using .DynamicAssortment
6870

6971
# Interface
7072
export AbstractBenchmark, AbstractStochasticBenchmark, AbstractDynamicBenchmark, DataSample
@@ -87,5 +89,6 @@ export FixedSizeShortestPathBenchmark
8789
export PortfolioOptimizationBenchmark
8890
export StochasticVehicleSchedulingBenchmark
8991
export DVSPBenchmark
92+
export DynamicAssortmentBenchmark
9093

9194
end # module DecisionFocusedLearningBenchmarks
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module DynamicAssortment

using ..Utils

using Combinatorics: combinations
using CommonRLInterface: CommonRLInterface, AbstractEnv
using Distributions: Uniform, Categorical
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using Flux: Chain, Dense
using LinearAlgebra: dot
using Random: Random, AbstractRNG, MersenneTwister
using Statistics: mean

export DynamicAssortmentBenchmark

include("environment.jl")

"""
$TYPEDEF

Dynamic assortment benchmark. It carries no fields; the problem data lives in the
`Instance` attached to each generated sample.
"""
struct DynamicAssortmentBenchmark <: AbstractDynamicBenchmark end

"""
$TYPEDSIGNATURES

Return a `DataSample` whose `instance` is a default-constructed `Instance`.
"""
Utils.generate_sample(::DynamicAssortmentBenchmark) = DataSample(; instance=Instance())

"""
$TYPEDSIGNATURES

Return the combinatorial maximizer of the benchmark, i.e. the function `DAP_optimization`.
"""
Utils.generate_maximizer(::DynamicAssortmentBenchmark) = DAP_optimization

end
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""
2+
$TYPEDEF

Static description of a dynamic assortment problem.

Every item is described by `d + 4` features, stored as the rows of the feature matrix
of the environment:
- rows `1:d`: random static features
- row `d + 1`: hype
- row `d + 2`: satisfaction
- row `d + 3`: price
- row `d + 4`: feature initialised to one when the environment is reset

The default `customer_choice_model` has six input weights, so it only matches the
default value `d = 2`; pass a model with `d + 4` inputs when changing `d`.

# Fields
$TYPEDFIELDS
"""
@kwdef struct Instance{M}
    "customer choice model (maps a `(d + 4) × N` feature matrix to `N` utilities)"
    customer_choice_model::M = Chain(Dense([0.3 0.5 0.6 -0.4 -0.8 0.0]), vec)
    "number of items"
    N::Int = 20
    "dimension of feature vectors (in addition to hype, satisfaction, and price)"
    d::Int = 2
    "assortment size constraint"
    K::Int = 4
    "number of steps per episode"
    max_steps::Int = 80
    "flags if the environment is endogenous"
    endogenous::Bool = true
end
26+
27+
"""
$TYPEDEF

Mutable state of a dynamic assortment episode.

The instance type is a (trailing) type parameter so that the `instance` field is
concretely typed; `Environment{R}` remains a valid type for dispatch.

# Fields
$TYPEDFIELDS
"""
@kwdef mutable struct Environment{R<:AbstractRNG,I<:Instance} <: AbstractEnv
    "associated instance"
    instance::I
    "current step"
    step::Int
    "purchase history (item index of each past purchase, 0 meaning no purchase)"
    purchase_hist::Vector{Int}
    "rng"
    rng::R
    "seed for RNG"
    seed::Int
    "customer utility for each item"
    utility::Vector{Float64}
    "prices for each item, followed by a final 0.0 entry for the no-purchase option"
    prices::Vector{Float64}
    "current full features"
    features::Matrix{Float64}
    "starting hype and satisfaction features"
    start_features::Matrix{Float64}
    "hype and satisfaction feature change from the last step"
    d_features::Matrix{Float64}
end
49+
50+
"""
$TYPEDSIGNATURES

Build an environment for `instance` with zero-filled state arrays and `step = 1`.
The arrays only receive meaningful values once `CommonRLInterface.reset!` is called.

# Keywords
- `seed`: seed stored in the environment (default `0`).
- `rng`: random number generator, by default a `MersenneTwister` seeded with `seed`.
"""
function Environment(
    instance::Instance; seed::Int=0, rng::AbstractRNG=MersenneTwister(seed)
)
    (; N, d) = instance
    return Environment(;
        instance,
        step=1,
        purchase_hist=Int[],
        rng,
        seed,
        utility=zeros(N),
        prices=zeros(N + 1),
        features=zeros(d + 4, N),
        start_features=zeros(2, N),
        d_features=zeros(2, N),
    )
end
66+
67+
## Basic operations of environment
68+
69+
# Reset the environment
70+
"""
$TYPEDSIGNATURES

Start a new episode: draw fresh item features, recompute prices and utilities, and
clear the purchase history.

# Keywords
- `reset_seed`: when `true`, re-seed `env.rng` with `seed` before sampling.
- `seed`: value stored in `env.seed` (defaults to the current one).
"""
function CommonRLInterface.reset!(env::Environment; reset_seed=false, seed=env.seed)
    env.seed = seed
    reset_seed && Random.seed!(env.rng, env.seed)

    (; d, N, customer_choice_model) = env.instance
    # d static rows, then hype, satisfaction and price, all uniform on [1, 10]
    sampled = rand(env.rng, Uniform(1.0, 10.0), (d + 3, N))
    # the last row of the full feature matrix starts at one
    full_features = vcat(sampled, ones(1, N))

    env.step = 1
    env.purchase_hist = Int[]
    # price is the last sampled row; the extra 0.0 is the no-purchase revenue
    env.prices = vcat(sampled[end, :], 0.0)
    env.features .= full_features
    env.start_features .= full_features[(d + 1):(d + 2), :]
    env.d_features .= 0.0
    env.utility .= customer_choice_model(full_features)
    return nothing
end
87+
88+
# Update the hype vector
89+
"""
$TYPEDSIGNATURES

Compute the multiplicative hype factor of every item from the purchase history.

All factors start at one. The most recent purchase gets `+0.02`; each of the (up to)
four purchases before it gets `-0.005`. Entries equal to `0` (no purchase) are ignored.
With an empty history, a vector of ones is returned.

Despite the `!` in its name, this function does not modify `env`.
"""
function hype_update!(env::Environment)
    hype_vector = ones(env.instance.N)
    history = env.purchase_hist
    # nothing bought yet: no hype change (indexing `history[end]` would throw)
    isempty(history) && return hype_vector

    latest = history[end]
    if latest != 0
        hype_vector[latest] += 0.02
    end
    # look back at most four purchases before the latest one
    for lag in 1:min(4, length(history) - 1)
        past_item = history[end - lag]
        if past_item != 0
            hype_vector[past_item] -= 0.005
        end
    end
    return hype_vector
end
122+
123+
# Step function
124+
"""
$TYPEDSIGNATURES

Record the purchase of `item` (`0` meaning no purchase) and advance the environment by
one step.

When the instance is endogenous, the hype row is scaled by the factors returned by
`hype_update!`, the satisfaction of the purchased item is multiplied by `1.01`, and the
last feature row is increased by `9 / max_steps`. `env.d_features` is then set to the
change of the hype and satisfaction rows during this step.

Row indices are derived from `instance.d` (hype `d + 1`, satisfaction `d + 2`, last row
`d + 4`), which matches the layout used when the environment is reset and reduces to
rows 3, 4 and 6 for the default `d = 2`.
"""
function step!(env::Environment, item)
    (; d, max_steps, endogenous) = env.instance
    hype_row = d + 1
    satisfaction_row = d + 2
    last_row = d + 4
    dynamic_rows = hype_row:satisfaction_row

    # only the dynamic rows are needed to compute the per-step change
    previous = env.features[dynamic_rows, :]
    push!(env.purchase_hist, item)
    if endogenous
        hype_vector = hype_update!(env)
        env.features[hype_row, :] .*= hype_vector
        if item != 0
            env.features[satisfaction_row, item] *= 1.01
        end
        env.features[last_row, :] .+= 9 / max_steps
    end
    env.d_features = env.features[dynamic_rows, :] - previous
    env.step += 1
    return nothing
end
137+
138+
# Choice probabilities
139+
"""
$TYPEDSIGNATURES

Return the purchase probabilities of the `N` items under assortment `S`, followed by
the probability of no purchase (entry `N + 1`).

Item `i` has weight `exp(env.utility[i]) * S[i]`, the no-purchase option has weight
one, and probabilities are the weights divided by their total.
"""
function choice_probabilities(env::Environment, S)
    items = 1:(env.instance.N)
    utilities = env.utility
    weights = map(i -> exp(utilities[i]) * S[i], items)
    total = 1 + sum(weights)
    # last entry: probability of no purchase
    return vcat(map(i -> weights[i] / total, items), 1 / total)
end
147+
148+
# Purchase decision
149+
"""
$TYPEDSIGNATURES

Sample the customer's choice for assortment `S` using `env.rng`.

Returns `(item, revenue)`: the index and price of the purchased item, or `(0, 0.0)`
when the customer buys nothing.
"""
function purchase!(env::Environment, S)
    no_purchase = env.instance.N + 1
    drawn = rand(env.rng, Categorical(choice_probabilities(env, S)))
    # the last category of the distribution is the no-purchase option
    if drawn == no_purchase
        return 0, 0.0
    end
    return drawn, env.prices[drawn]
end
157+
158+
"""
$TYPEDSIGNATURES

Enumerate every assortment of exactly `K` items and return, as a 0/1 vector of length
`N`, the one with the highest expected one-step revenue under the current utilities
and prices. Ties are broken in favour of the first assortment enumerated.

The search is exhaustive (`binomial(N, K)` candidates). If there is no candidate at
all (`K > N`), the empty assortment `zeros(N)` is returned.
"""
function expert_solution(env::Environment)
    (; N, K) = env.instance
    r = env.prices
    # Start from -Inf so that the first candidate is always accepted, even when all
    # expected revenues are zero (e.g. before the first reset) or negative.
    best_S = zeros(N)
    best_revenue = -Inf
    for S in combinations(1:N, K)
        S_vec = zeros(N)
        S_vec[S] .= 1.0
        expected_revenue = dot(choice_probabilities(env, S_vec), r)
        if expected_revenue > best_revenue
            best_S, best_revenue = S_vec, expected_revenue
        end
    end
    return best_S
end
175+
176+
# DAP CO-layer
177+
"""
$TYPEDSIGNATURES

Combinatorial layer: return a 0/1 vector of length `instance.N` whose ones mark the
`instance.K` largest entries of `θ`. Exactly `K` items are always selected.
"""
function DAP_optimization(θ; instance::Instance)
    (; N, K) = instance
    S = zeros(N)
    for item in partialsortperm(θ, 1:K; rev=true)
        S[item] = 1
    end
    return S
end
183+
184+
## Solution functions
185+
186+
# Anticipative (fixed)
187+
"""
$TYPEDSIGNATURES

Run `episodes` episodes in which the assortment shown at every step is the one
returned by `expert_solution`, and collect the visited (features, assortment) pairs.

Episode `i` resets the environment with seed `first_seed - 1 + i` and re-seeds the rng.
Each recorded feature matrix stacks the current features, the last-step change of the
hype/satisfaction rows and their change since the start of the episode.

# Keywords
- `first_seed`: seed of the first episode.
- `use_oracle`: currently unused; kept for interface compatibility.

# Returns
`(mean_revenue, revenues, dataset)` where `revenues[i]` is the total revenue of episode
`i` and `dataset[i]` is the vector of `(features, S_t)` named tuples of that episode.
"""
function expert_policy(env::Environment, episodes; first_seed=1, use_oracle=false)
    (; d, max_steps) = env.instance
    # hype and satisfaction rows, consistent with the layout built at reset
    dynamic_rows = (d + 1):(d + 2)
    Sample = @NamedTuple{features::Matrix{Float64}, S_t::Vector{Float64}}
    dataset = Vector{Sample}[]
    rev_global = Float64[]
    for i in 1:episodes
        CommonRLInterface.reset!(env; seed=first_seed - 1 + i, reset_seed=true)
        rev_episode = 0.0
        training_instances = Sample[]
        while true
            S = expert_solution(env)

            delta_features = env.features[dynamic_rows, :] .- env.start_features
            feature_vector = vcat(env.features, env.d_features, delta_features)
            push!(training_instances, (features=feature_vector, S_t=S))

            item, revenue = purchase!(env, S)
            rev_episode += revenue
            step!(env, item)

            env.step > max_steps && break
        end
        push!(rev_global, rev_episode)
        push!(dataset, training_instances)
    end
    return mean(rev_global), rev_global, dataset
end
213+
214+
# Greedy heuristic
215+
"""
$TYPEDSIGNATURES

Greedy scoring heuristic: apply a fixed linear layer with ten input weights, all zero
except a weight of one on the fifth input row (the price row for the default `d = 2`),
and return the resulting score vector, one entry per column of `features`.
"""
function model_greedy(features)
    weights = zeros(1, 10)
    weights[1, 5] = 1.0
    scorer = Dense(weights)
    return vec(scorer(features))
end
219+
220+
# Random heuristic
221+
"""
$TYPEDSIGNATURES

Random scoring heuristic: return one uniform score in `[0, 1]` per column of
`features`. The generator is seeded with the rounded sum of `features`, so equal
inputs give equal scores.
"""
function model_random(features)
    seed = round(Int, sum(features))
    n_items = size(features, 2)
    return rand(MersenneTwister(seed), Uniform(0.0, 1.0), n_items)
end
225+
226+
# Episode generation
227+
"""
$TYPEDSIGNATURES

Roll out one episode with a perturbed policy and return the list of transitions.

At each step the stacked feature matrix (current features, last-step change and change
since the episode start of the hype/satisfaction rows) is fed to `model`, the resulting
scores `θ` are perturbed into `η`, `DAP_optimization` turns `η` into an assortment, and
the customer purchase is sampled and applied to the environment.

Each transition is a named tuple with fields `t`, `feat_t`, `theta`, `eta`, `S_t`,
`a_T`, `rev_t`, `ret_t` and `feat_next`; `ret_t` is the discounted return (discount
`0.99`) from that step to the end of the episode.

# Arguments
- `model`: maps the stacked feature matrix to one score per item.
- `customer_model`: currently unused; kept for interface compatibility.
- `sigma`: scale of the perturbation applied to `θ`.
- `random_seed`: seed of the episode; the perturbation at step `t` uses a generator
  seeded with `random_seed * t`.

NOTE(review): the previous version drew `η` from an undefined `p(θ, sigma)`; this
version assumes an isotropic Gaussian perturbation with standard deviation `sigma` —
confirm against the intended distribution.
NOTE(review): the environment rng is re-seeded on reset (`reset_seed=true`), as in
`expert_policy` — confirm this is the intended meaning of `random_seed`.
"""
function generate_episode(env::Environment, model, customer_model, sigma, random_seed)
    (; d, max_steps) = env.instance
    # hype and satisfaction rows, consistent with the layout built at reset
    dynamic_rows = (d + 1):(d + 2)
    buffer = NamedTuple[]
    CommonRLInterface.reset!(env; seed=random_seed, reset_seed=true)
    while true
        delta_features = env.features[dynamic_rows, :] .- env.start_features
        feature_vector = vcat(env.features, env.d_features, delta_features)
        θ = model(feature_vector)
        perturbation_rng = MersenneTwister(random_seed * env.step)
        η = θ .+ sigma .* randn(perturbation_rng, length(θ))
        S = DAP_optimization(η; instance=env.instance)
        item, revenue = purchase!(env, S)
        step!(env, item)
        feat_next = vcat(
            env.features,
            env.d_features,
            env.features[dynamic_rows, :] .- env.start_features,
        )
        push!(
            buffer,
            (
                t=env.step - 1,
                feat_t=feature_vector,
                theta=θ,
                eta=η,
                S_t=S,
                a_T=item,
                rev_t=revenue,
                ret_t=0.0,
                feat_next=feat_next,
            ),
        )
        env.step > max_steps && break
    end
    # Backward pass for the discounted returns. It starts at the last transition so
    # that the final revenue is included (it used to be left out of every return).
    ret = 0.0
    for i in length(buffer):-1:1
        ret = buffer[i].rev_t + 0.99 * ret
        buffer[i] = (; buffer[i]..., ret_t=ret)
    end
    return buffer
end

0 commit comments

Comments
 (0)