Commit e8057cf
wip
1 parent 0c7e20a commit e8057cf

1 file changed: +56 −55 lines changed

src/DynamicAssortment/environment.jl

Lines changed: 56 additions & 55 deletions
@@ -11,7 +11,7 @@ $TYPEDFIELDS
 """
 @kwdef struct Instance{M}
     "customer choice model"
-    customer_choice_model::M = Chain(Dense([0.3 0.5 0.6 -0.4 -0.8 0.0]), vec)
+    customer_choice_model::M = Chain(Dense([0.3 0.5 0.6 -0.4 -0.8]), vec)
     "number of items"
     N::Int = 20
     "dimension of feature vectors (in addition to hype, satisfaction, and price)"
@@ -22,6 +22,7 @@ $TYPEDFIELDS
     max_steps::Int = 80
     "flags if the environment is endogenous"
     endogenous::Bool = true
+    # start_features?
 end

 @kwdef mutable struct Environment{R<:AbstractRNG} <: AbstractEnv
@@ -58,7 +59,7 @@ function Environment(
         seed=seed,
         utility=zeros(instance.N),
         prices=zeros(instance.N + 1),
-        features=zeros(instance.d + 4, instance.N),
+        features=zeros(instance.d + 3, instance.N),
         start_features=zeros(2, instance.N),
         d_features=zeros(2, instance.N),
     )
@@ -75,7 +76,7 @@ function CommonRLInterface.reset!(env::Environment; reset_seed=false, seed=env.s
     (; d, N, customer_choice_model) = env.instance
     features = rand(env.rng, Uniform(1.0, 10.0), (d + 3, N))
     env.prices = vcat(features[end, :], 0.0)
-    features = vcat(features, ones(1, N))
+    # features = vcat(features, ones(1, N)) # TODO
     env.d_features .= 0.0
     env.step = 1
     env.utility .= customer_choice_model(features)
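This is the companion change to the d + 4 → d + 3 buffer above: reset! no longer appends the ones row, so the sampled matrix is the full feature set and its last row doubles as the price row. A small consistency sketch under the same assumptions (d = 2 and N = 20, the defaults from this file):

    using Random: MersenneTwister
    using Distributions: Uniform

    rng = MersenneTwister(0)
    d, N = 2, 20
    features = rand(rng, Uniform(1.0, 10.0), (d + 3, N))  # 5×20; last row holds prices
    prices = vcat(features[end, :], 0.0)                  # N + 1 entries; the extra 0.0
                                                          # is the no-purchase option
    @assert size(features) == (d + 3, N)
    @assert length(prices) == N + 1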
@@ -128,7 +129,7 @@ function step!(env::Environment, item)
        hype_vector = hype_update!(env)
        env.features[3, :] .*= hype_vector
        item != 0 ? env.features[4, item] *= 1.01 : nothing
-       env.features[6, :] .+= 9 / env.instance.max_steps # ??
+       # env.features[6, :] .+= 9 / env.instance.max_steps # ??
    end
    env.d_features = env.features[3:4, :] - old_features[3:4, :] # ! hardcoded everywhere :(
    env.step += 1
@@ -146,13 +147,14 @@ function choice_probabilities(env::Environment, S)
 end

 # Purchase decision
-function purchase!(env::Environment, S)
+function CommonRLInterface.act!(env::Environment, S)
     r = env.prices
     probs = choice_probabilities(env, S)
     item = rand(env.rng, Categorical(probs))
+    reward = r[item]
     item == env.instance.N + 1 ? item = 0 : item # TODO: cleanup this, not really needed and confusing
-    item != 0 ? revenue = r[item] : revenue = 0.0
-    return item, revenue
+    step!(env, item)
+    return reward
 end

 # enumerate all possible assortments of size K and return the best one
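With purchase! folded into CommonRLInterface.act!, a single call now samples the customer's choice, advances the environment via step!, and returns the realized revenue (env.prices carries a trailing 0.0 for the no-purchase option, so the reward is 0.0 when nothing sells). A rollout sketch against that interface; choose_assortment is a hypothetical policy, not part of this file:

    using CommonRLInterface: reset!, act!

    function rollout!(env)
        reset!(env; seed=1)
        total_revenue = 0.0
        while env.step <= env.instance.max_steps
            S = choose_assortment(env)     # hypothetical policy returning an assortment
            total_revenue += act!(env, S)  # samples the purchase, calls step!, returns revenue
        end
        return total_revenue
    end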
@@ -199,9 +201,8 @@ function expert_policy(env::Environment, episodes; first_seed=1, use_oracle=fals
        feature_vector = vcat(env.features, env.d_features, delta_features)
        push!(training_instances, (features=feature_vector, S_t=S))

-       item, revenue = purchase!(env, S)
-       rev_episode += revenue
-       step!(env, item)
+       reward = CommonRLInterface.act!(env, S)
+       rev_episode += reward

        env.step > env.instance.max_steps ? done = true : done = false
    end
@@ -224,48 +225,48 @@ function model_random(features)
 end

 # Episode generation
-function generate_episode(env::Environment, model, customer_model, sigma, random_seed)
-    buffer = []
-    start_features, d_features = reset!(env; seed=random_seed)
-    features = copy(start_features)
-    done = false
-    while !done
-        delta_features = features[3:4, :] .- start_features[3:4, :]
-        r = features[5, :]
-        feature_vector = vcat(features, d_features, delta_features)
-        θ = model(feature_vector)
-        η = rand(MersenneTwister(random_seed * env.step), p(θ, sigma), 1)[:, 1]
-        S = DAP_optimization(η; instance=env.instance)
-        θ_0 = customer_model(features)
-        item, revenue = purchase!(env, S)
-        features, d_features = step!(env, features, item)
-        feat_next = vcat(features, d_features, features[3:4, :] .- start_features[3:4, :])
-        push!(
-            buffer,
-            (
-                t=env.step - 1,
-                feat_t=feature_vector,
-                theta=θ,
-                eta=η,
-                S_t=S,
-                a_T=item,
-                rev_t=revenue,
-                ret_t=0.0,
-                feat_next=feat_next,
-            ),
-        )
-        count(!iszero, inventory) < env.instance.K ? break : nothing
-        env.step > env.instance.max_steps ? done = true : done = false
-    end
-    for i in (length(buffer) - 1):-1:1
-        if i == length(buffer) - 1
-            ret = buffer[i].rev_t
-        else
-            ret = buffer[i].rev_t + 0.99 * buffer[i + 1].ret_t
-        end
-        traj = buffer[i]
-        traj_updated = (; traj..., ret_t=ret)
-        buffer[i] = traj_updated
-    end
-    return buffer
-end
+# function generate_episode(env::Environment, model, customer_model, sigma, random_seed)
+#     buffer = []
+#     start_features, d_features = reset!(env; seed=random_seed)
+#     features = copy(start_features)
+#     done = false
+#     while !done
+#         delta_features = features[3:4, :] .- start_features[3:4, :]
+#         r = features[5, :]
+#         feature_vector = vcat(features, d_features, delta_features)
+#         θ = model(feature_vector)
+#         η = rand(MersenneTwister(random_seed * env.step), p(θ, sigma), 1)[:, 1]
+#         S = DAP_optimization(η; instance=env.instance)
+#         θ_0 = customer_model(features)
+#         item, revenue = purchase!(env, S)
+#         features, d_features = step!(env, item)
+#         feat_next = vcat(features, d_features, features[3:4, :] .- start_features[3:4, :])
+#         push!(
+#             buffer,
+#             (
+#                 t=env.step - 1,
+#                 feat_t=feature_vector,
+#                 theta=θ,
+#                 eta=η,
+#                 S_t=S,
+#                 a_T=item,
+#                 rev_t=revenue,
+#                 ret_t=0.0,
+#                 feat_next=feat_next,
+#             ),
+#         )
+#         count(!iszero, inventory) < env.instance.K ? break : nothing
+#         env.step > env.instance.max_steps ? done = true : done = false
+#     end
+#     for i in (length(buffer) - 1):-1:1
+#         if i == length(buffer) - 1
+#             ret = buffer[i].rev_t
+#         else
+#             ret = buffer[i].rev_t + 0.99 * buffer[i + 1].ret_t
+#         end
+#         traj = buffer[i]
+#         traj_updated = (; traj..., ret_t=ret)
+#         buffer[i] = traj_updated
+#     end
+#     return buffer
+# end
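For reference while generate_episode is parked: its closing loop fills ret_t backwards with a 0.99 discount, starting at length(buffer) - 1. A standalone sketch of that bookkeeping, written to also seed the final entry with its own revenue; field names follow the buffer tuples above:

    # Backward discounted-return pass over a vector of NamedTuples
    # with fields rev_t (step revenue) and ret_t (return to fill in).
    function discounted_returns!(buffer; gamma=0.99)
        for i in length(buffer):-1:1
            ret = if i == length(buffer)
                buffer[i].rev_t                          # terminal step: no successor
            else
                buffer[i].rev_t + gamma * buffer[i + 1].ret_t
            end
            buffer[i] = (; buffer[i]..., ret_t=ret)      # rebuild: NamedTuples are immutable
        end
        return buffer
    end

    # discounted_returns!([(rev_t=1.0, ret_t=0.0), (rev_t=2.0, ret_t=0.0)])
    # -> ret_t becomes [1.0 + 0.99 * 2.0, 2.0]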
