@@ -11,7 +11,7 @@ $TYPEDFIELDS
 """
 @kwdef struct Instance{M}
     "customer choice model"
-    customer_choice_model::M = Chain(Dense([0.3 0.5 0.6 -0.4 -0.8 0.0]), vec)
+    customer_choice_model::M = Chain(Dense([0.3 0.5 0.6 -0.4 -0.8]), vec)
     "number of items"
     N::Int = 20
     "dimension of feature vectors (in addition to hype, satisfaction, and price)"
@@ -22,6 +22,7 @@ $TYPEDFIELDS
     max_steps::Int = 80
     "flags if the environment is endogenous"
     endogenous::Bool = true
+    # start_features?
 end
 
 @kwdef mutable struct Environment{R<:AbstractRNG} <: AbstractEnv
@@ -58,7 +59,7 @@ function Environment(
         seed=seed,
         utility=zeros(instance.N),
         prices=zeros(instance.N + 1),
-        features=zeros(instance.d + 4, instance.N),
+        features=zeros(instance.d + 3, instance.N),
         start_features=zeros(2, instance.N),
         d_features=zeros(2, instance.N),
     )
@@ -75,7 +76,7 @@ function CommonRLInterface.reset!(env::Environment; reset_seed=false, seed=env.seed)
     (; d, N, customer_choice_model) = env.instance
     features = rand(env.rng, Uniform(1.0, 10.0), (d + 3, N))
     env.prices = vcat(features[end, :], 0.0)
-    features = vcat(features, ones(1, N))
+    # features = vcat(features, ones(1, N)) # TODO
     env.d_features .= 0.0
     env.step = 1
     env.utility .= customer_choice_model(features)
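
For orientation, a small sketch of the feature layout this `reset!` implies. The row semantics are read off the surrounding indexing (rows 3 and 4 are hype and satisfaction per `step!`, the last row is price), not stated anywhere explicitly:

```julia
using Distributions

d, N = 2, 20
# Rows 1:d are generic item features, then hype, satisfaction, price.
features = rand(Uniform(1.0, 10.0), (d + 3, N))
prices = vcat(features[end, :], 0.0)  # index N + 1 is the no-purchase option
```
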
@@ -128,7 +129,7 @@ function step!(env::Environment, item)
         hype_vector = hype_update!(env)
         env.features[3, :] .*= hype_vector
         item != 0 ? env.features[4, item] *= 1.01 : nothing
-        env.features[6, :] .+= 9 / env.instance.max_steps # ??
+        # env.features[6, :] .+= 9 / env.instance.max_steps # ??
     end
     env.d_features = env.features[3:4, :] - old_features[3:4, :] # ! hardcoded everywhere :(
     env.step += 1
@@ -146,13 +147,14 @@ function choice_probabilities(env::Environment, S)
 end
 
 # Purchase decision
-function purchase!(env::Environment, S)
+function CommonRLInterface.act!(env::Environment, S)
     r = env.prices
     probs = choice_probabilities(env, S)
     item = rand(env.rng, Categorical(probs))
+    reward = r[item]
     item == env.instance.N + 1 ? item = 0 : item # TODO: cleanup this, not really needed and confusing
-    item != 0 ? revenue = r[item] : revenue = 0.0
-    return item, revenue
+    step!(env, item)
+    return reward
 end
 
 # enumerate all possible assortments of size K and return the best one
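
With `purchase!` folded into `CommonRLInterface.act!`, a single call now samples the customer's purchase, advances the environment via `step!`, and returns the collected price as the reward (the no-purchase option pays `prices[N + 1] = 0.0`). A minimal driver sketch under that convention; `rollout` and `pick_assortment` are illustrative names, not part of this file:

```julia
using CommonRLInterface

function rollout(env::Environment, pick_assortment)
    CommonRLInterface.reset!(env)
    total = 0.0
    while env.step <= env.instance.max_steps
        S = pick_assortment(env)                 # choose an assortment to offer
        total += CommonRLInterface.act!(env, S)  # purchase + step!, returns reward
    end
    return total
end
```
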
@@ -199,9 +201,8 @@ function expert_policy(env::Environment, episodes; first_seed=1, use_oracle=false)
             feature_vector = vcat(env.features, env.d_features, delta_features)
             push!(training_instances, (features=feature_vector, S_t=S))
 
-            item, revenue = purchase!(env, S)
-            rev_episode += revenue
-            step!(env, item)
+            reward = CommonRLInterface.act!(env, S)
+            rev_episode += reward
 
             env.step > env.instance.max_steps ? done = true : done = false
         end
@@ -224,48 +225,48 @@ function model_random(features)
224225end
225226
226227# Episode generation
227- function generate_episode (env:: Environment , model, customer_model, sigma, random_seed)
228- buffer = []
229- start_features, d_features = reset! (env; seed= random_seed)
230- features = copy (start_features)
231- done = false
232- while ! done
233- delta_features = features[3 : 4 , :] .- start_features[3 : 4 , :]
234- r = features[5 , :]
235- feature_vector = vcat (features, d_features, delta_features)
236- θ = model (feature_vector)
237- η = rand (MersenneTwister (random_seed * env. step), p (θ, sigma), 1 )[:, 1 ]
238- S = DAP_optimization (η; instance= env. instance)
239- θ_0 = customer_model (features)
240- item, revenue = purchase! (env, S)
241- features, d_features = step! (env, features , item)
242- feat_next = vcat (features, d_features, features[3 : 4 , :] .- start_features[3 : 4 , :])
243- push! (
244- buffer,
245- (
246- t= env. step - 1 ,
247- feat_t= feature_vector,
248- theta= θ,
249- eta= η,
250- S_t= S,
251- a_T= item,
252- rev_t= revenue,
253- ret_t= 0.0 ,
254- feat_next= feat_next,
255- ),
256- )
257- count (! iszero, inventory) < env. instance. K ? break : nothing
258- env. step > env. instance. max_steps ? done = true : done = false
259- end
260- for i in (length (buffer) - 1 ): - 1 : 1
261- if i == length (buffer) - 1
262- ret = buffer[i]. rev_t
263- else
264- ret = buffer[i]. rev_t + 0.99 * buffer[i + 1 ]. ret_t
265- end
266- traj = buffer[i]
267- traj_updated = (; traj... , ret_t= ret)
268- buffer[i] = traj_updated
269- end
270- return buffer
271- end
228+ # function generate_episode(env::Environment, model, customer_model, sigma, random_seed)
229+ # buffer = []
230+ # start_features, d_features = reset!(env; seed=random_seed)
231+ # features = copy(start_features)
232+ # done = false
233+ # while !done
234+ # delta_features = features[3:4, :] .- start_features[3:4, :]
235+ # r = features[5, :]
236+ # feature_vector = vcat(features, d_features, delta_features)
237+ # θ = model(feature_vector)
238+ # η = rand(MersenneTwister(random_seed * env.step), p(θ, sigma), 1)[:, 1]
239+ # S = DAP_optimization(η; instance=env.instance)
240+ # θ_0 = customer_model(features)
241+ # item, revenue = purchase!(env, S)
242+ # features, d_features = step!(env, item)
243+ # feat_next = vcat(features, d_features, features[3:4, :] .- start_features[3:4, :])
244+ # push!(
245+ # buffer,
246+ # (
247+ # t=env.step - 1,
248+ # feat_t=feature_vector,
249+ # theta=θ,
250+ # eta=η,
251+ # S_t=S,
252+ # a_T=item,
253+ # rev_t=revenue,
254+ # ret_t=0.0,
255+ # feat_next=feat_next,
256+ # ),
257+ # )
258+ # count(!iszero, inventory) < env.instance.K ? break : nothing
259+ # env.step > env.instance.max_steps ? done = true : done = false
260+ # end
261+ # for i in (length(buffer) - 1):-1:1
262+ # if i == length(buffer) - 1
263+ # ret = buffer[i].rev_t
264+ # else
265+ # ret = buffer[i].rev_t + 0.99 * buffer[i + 1].ret_t
266+ # end
267+ # traj = buffer[i]
268+ # traj_updated = (; traj..., ret_t=ret)
269+ # buffer[i] = traj_updated
270+ # end
271+ # return buffer
272+ # end
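
The commented-out generator above backfills discounted returns in reverse with γ = 0.99. The same recursion in isolation, for reference (`backfill_returns` is an illustrative name; note the original loop starts at `length(buffer) - 1`, so the final buffer entry keeps `ret_t = 0.0`):

```julia
# Backward pass: ret_t = rev_t + γ * ret_{t+1}.
function backfill_returns(revenues::Vector{Float64}; γ=0.99)
    returns = similar(revenues)
    returns[end] = revenues[end]
    for i in (length(revenues) - 1):-1:1
        returns[i] = revenues[i] + γ * returns[i + 1]
    end
    return returns
end

backfill_returns([1.0, 0.0, 2.0])  # ≈ [2.9602, 1.98, 2.0]
```
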