Skip to content

Commit e4454a5

Browse files
committed
Updated greedy and TS algorithm
1 parent 3289608 commit e4454a5

File tree

7 files changed

+401
-349
lines changed

7 files changed

+401
-349
lines changed

ddopai/_modidx.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,18 +308,12 @@
308308
'ddopai/agents/dynamic_pricing/TS.py'),
309309
'ddopai.agents.dynamic_pricing.TS.TSPolicy.__init__': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.__init__',
310310
'ddopai/agents/dynamic_pricing/TS.py'),
311-
'ddopai.agents.dynamic_pricing.TS.TSPolicy.compute_uncertainty_M': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.compute_uncertainty_m',
312-
'ddopai/agents/dynamic_pricing/TS.py'),
313311
'ddopai.agents.dynamic_pricing.TS.TSPolicy.draw_action': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.draw_action',
314312
'ddopai/agents/dynamic_pricing/TS.py'),
315313
'ddopai.agents.dynamic_pricing.TS.TSPolicy.fit': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.fit',
316314
'ddopai/agents/dynamic_pricing/TS.py'),
317-
'ddopai.agents.dynamic_pricing.TS.TSPolicy.parameter_update': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.parameter_update',
318-
'ddopai/agents/dynamic_pricing/TS.py'),
319315
'ddopai.agents.dynamic_pricing.TS.TSPolicy.reset': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.reset',
320316
'ddopai/agents/dynamic_pricing/TS.py'),
321-
'ddopai.agents.dynamic_pricing.TS.TSPolicy.sample_design_matrix': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.sample_design_matrix',
322-
'ddopai/agents/dynamic_pricing/TS.py'),
323317
'ddopai.agents.dynamic_pricing.TS.TSPolicy.update_task': ( '30_agents/42_DP_agents/ts_agent.html#tspolicy.update_task',
324318
'ddopai/agents/dynamic_pricing/TS.py')},
325319
'ddopai.agents.dynamic_pricing.UCB': { 'ddopai.agents.dynamic_pricing.UCB.UCBAgent': ( '30_agents/42_DP_agents/ucb_agent.html#ucbagent',
@@ -404,12 +398,18 @@
404398
'ddopai/agents/dynamic_pricing/greedy.py'),
405399
'ddopai.agents.dynamic_pricing.greedy.GreedyPolicy.fit': ( '30_agents/42_DP_agents/greedy_agent.html#greedypolicy.fit',
406400
'ddopai/agents/dynamic_pricing/greedy.py'),
407-
'ddopai.agents.dynamic_pricing.greedy.GreedyPolicy.parameter_update': ( '30_agents/42_DP_agents/greedy_agent.html#greedypolicy.parameter_update',
408-
'ddopai/agents/dynamic_pricing/greedy.py'),
409401
'ddopai.agents.dynamic_pricing.greedy.GreedyPolicy.reset': ( '30_agents/42_DP_agents/greedy_agent.html#greedypolicy.reset',
410402
'ddopai/agents/dynamic_pricing/greedy.py'),
411403
'ddopai.agents.dynamic_pricing.greedy.GreedyPolicy.update_task': ( '30_agents/42_DP_agents/greedy_agent.html#greedypolicy.update_task',
412-
'ddopai/agents/dynamic_pricing/greedy.py')},
404+
'ddopai/agents/dynamic_pricing/greedy.py'),
405+
'ddopai.agents.dynamic_pricing.greedy._OLSIncremental': ( '30_agents/42_DP_agents/greedy_agent.html#_olsincremental',
406+
'ddopai/agents/dynamic_pricing/greedy.py'),
407+
'ddopai.agents.dynamic_pricing.greedy._OLSIncremental.__init__': ( '30_agents/42_DP_agents/greedy_agent.html#_olsincremental.__init__',
408+
'ddopai/agents/dynamic_pricing/greedy.py'),
409+
'ddopai.agents.dynamic_pricing.greedy._OLSIncremental.theta_hat': ( '30_agents/42_DP_agents/greedy_agent.html#_olsincremental.theta_hat',
410+
'ddopai/agents/dynamic_pricing/greedy.py'),
411+
'ddopai.agents.dynamic_pricing.greedy._OLSIncremental.update': ( '30_agents/42_DP_agents/greedy_agent.html#_olsincremental.update',
412+
'ddopai/agents/dynamic_pricing/greedy.py')},
413413
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP': { 'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpagent',
414414
'ddopai/agents/dynamic_pricing/inventory_constrained/IDP.py'),
415415
'ddopai.agents.dynamic_pricing.inventory_constrained.IDP.IDPAgent.__init__': ( '30_agents/42_DP_agents/421_DP_inventory_agents/idp_agent.html#idpagent.__init__',

ddopai/agents/dynamic_pricing/TS.py

Lines changed: 93 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -24,111 +24,107 @@
2424

2525

2626
# %% ../../../nbs/30_agents/42_DP_agents/12_TS_agent.ipynb 4
class TSPolicy:
    """
    Minimal Thompson-Sampling agent for the linear-demand model

        D = x⊤α + p · x⊤β + ε

    – Gaussian prior θ ∼ N(0, λ⁻¹I)
    – Incremental (Sherman–Morrison) ridge update of M_t = λI + Σ z zᵀ
      and θ̂_t = M_t⁻¹ q_t
    – One Gaussian posterior draw at each round, priced by p* = -a / (2b)
    """

    # ---------- ctor ---------------------------------------------------------
    def __init__(self,
                 lam: float,
                 environment_info: MDPInfo,
                 price_function,          # takes (x, a, b) ➜ price
                 actionprocessors=None,
                 warm_start_prices=None,
                 init_scale=None):
        """
        lam : ridge / prior precision λ
        price_function(x, a, b) returns the quadratic-optimal price (usually -a/2b)
        warm_start_prices : iterable of k ∈ {0,1,2,…} initial prices; can be empty
        init_scale : exploration std-multiplier; default √d / 25
        """
        d_feat = environment_info.observation_space['features'].shape[0]
        self.d_param = 2 * d_feat
        self.lam = lam
        # remember the caller's choice so update_task() can recompute the
        # dimension-dependent default for a new task
        self._init_scale = init_scale
        self.scale = init_scale or (np.sqrt(d_feat) / 25.0)

        # incremental posterior: M_inv = (λI + Σ z zᵀ)⁻¹, q = Σ z·D
        self.M_inv = np.eye(self.d_param) / lam if lam else np.eye(self.d_param)
        self.q = np.zeros(self.d_param)

        # current point estimate θ̂ = (α, β)
        self.alpha = np.zeros(d_feat)
        self.beta = np.zeros(d_feat)

        # misc — keep BOTH attribute names: __init__ historically wrote
        # `env_info` while update_task wrote `environment_info`; external
        # code may read either.
        self.env_info = environment_info
        self.environment_info = environment_info
        self.price_fn = price_function
        self.t = 0
        self.warm_p = np.asarray(warm_start_prices) if warm_start_prices is not None else np.empty(0)

        # processors (only clip to the action space)
        self.actionprocessors = actionprocessors or []
        self.actionprocessors.append(
            ClipAction(environment_info.action_space.low,
                       environment_info.action_space.high)
        )

    # ---------- draw_action ---------------------------------------------------
    def draw_action(self, observation):
        """Draw one posterior sample of θ and return the (clipped) greedy price."""
        x = observation['features']

        # warm-start if required: play the fixed exploration prices first
        if self.t < self.warm_p.size:
            p = self.warm_p[self.t]
        else:
            # posterior sample: θ = M⁻¹q + scale · chol(M⁻¹) ε, ε ∼ N(0, I)
            L = np.linalg.cholesky(self.M_inv)
            noise = np.random.randn(self.d_param)
            theta = self.M_inv @ self.q + self.scale * (L @ noise)

            a = x @ theta[:x.size]
            b = x @ theta[x.size:]
            # clamp so demand is positive at p=0 and strictly downward-sloping,
            # guaranteeing an interior maximiser for the quadratic revenue
            b = np.minimum(np.array([-0.01]), b)
            a = np.maximum(np.array([0.01]), a)

            p = self.price_fn(np.ones_like(a), a, b)  # usually -a / (2b)

        for proc in self.actionprocessors:
            p = proc(p)
        return np.array(p, dtype=np.float32)

    # ---------- fit -----------------------------------------------------------
    def fit(self, X, D, price):
        """Update posterior with one observation (features X, price, demand D)."""
        z = np.concatenate([X, X * price])  # regressor of length 2d
        # rank-1 Sherman–Morrison update of M⁻¹
        Mz = self.M_inv @ z
        self.M_inv -= np.outer(Mz, Mz) / (1.0 + z @ Mz)
        self.q += z * float(D)

        theta_hat = self.M_inv @ self.q
        split = theta_hat.size // 2
        self.alpha, self.beta = theta_hat[:split], theta_hat[split:]

        self.t += 1

    # ---------- helpers -------------------------------------------------------
    def reset(self):
        """No episodic state; the posterior deliberately persists across episodes."""
        pass

    def update_task(self, env):
        """Start fresh on a new MDP / feature dimension."""
        self.environment_info = env.mdp_info
        self.env_info = env.mdp_info
        d_feat = self.environment_info.observation_space['features'].shape[0]
        # BUG FIX: draw_action sizes its noise vector and splits θ by
        # self.d_param, but the old code re-allocated the posterior under a
        # different name (self.d), leaving d_param stale and causing a shape
        # mismatch in np.linalg.cholesky after a dimension change.
        self.d_param = 2 * d_feat
        self.d = self.d_param  # kept for backward compatibility with old readers
        self.M_inv = np.eye(self.d_param) / self.lam if self.lam != 0 else np.eye(self.d_param)
        self.q = np.zeros(self.d_param)
        # also restart the point estimate and the dimension-dependent scale
        self.alpha = np.zeros(d_feat)
        self.beta = np.zeros(d_feat)
        self.scale = self._init_scale or (np.sqrt(d_feat) / 25.0)
        self.actionprocessors[-1] = ClipAction(self.environment_info.action_space.low, self.environment_info.action_space.high)
        self.t = 0
127+
132128

133129
# %% ../../../nbs/30_agents/42_DP_agents/12_TS_agent.ipynb 5
134130
class TSCoreAgent(Agent):

ddopai/agents/dynamic_pricing/UCB.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ def __init__(self,
5555
self.lam = lam
5656
self.reg = reg
5757
self.t = 0
58-
self.X = np.empty((0, environment_info.observation_space['features'].shape[0] * 2))
59-
self.Y = np.empty((0, 1))
58+
6059
self.mode = "train"
6160
self.actionprocessors.append(ClipAction(environment_info.action_space.low, environment_info.action_space.high))
62-
self.d = environment_info.observation_space['features'].shape[0]
61+
62+
self.d = environment_info.observation_space['features'].shape[0] * 2
63+
self.M_inv = np.eye(self.d) / self.lam if self.lam != 0 else np.eye(self.d)
64+
self.q = np.zeros(self.d)
6365
def draw_action(self, observation):
6466
x = observation['features']
6567
# if self.t in [0, 1]:
@@ -76,8 +78,6 @@ def draw_action(self, observation):
7678
def fit(self, X, Y, action):
7779

7880
Z = np.concatenate([X, X * action])
79-
self.X = np.vstack([self.X, Z])
80-
self.Y = np.vstack([self.Y, Y])
8181
self.parameter_update(Z, Y)
8282
self.t += 1
8383

@@ -86,11 +86,6 @@ def parameter_update(self, z, D_t):
8686
One-step Sherman-Morrison update of the quasi-MLE for the *identity* link g(u)=u
8787
(linear demand). If you keep a general g, replace D_t by the *score* below.
8888
"""
89-
if self.t == 0:
90-
# first call: initialise
91-
d = len(z)
92-
self.M_inv = np.eye(d) / self.lam if self.lam != 0 else np.eye(d)
93-
self.q = np.zeros(d)
9489

9590
# rank-1 update of M_t^{-1}
9691
Mz = self.M_inv @ z
@@ -105,11 +100,6 @@ def parameter_update(self, z, D_t):
105100

106101

107102
def sample_design_matrix(self):
108-
# d = self.environment_info.observation_space['features'].shape[0]
109-
# I = self.lam * np.identity(2 * d)
110-
# if self.X.shape[0] == 0:
111-
# return I
112-
# return I + self.X.T @ self.X
113103
return np.linalg.inv(self.M_inv)
114104

115105
def sample_from_confidence_region(self, theta_hat, M, N=50, gamma=None):
@@ -146,9 +136,9 @@ def max_rev(self, samples, x):
146136
beta = theta[x.shape[0]:]
147137
a = np.dot(x, alpha)
148138
b = np.dot(x, beta)
149-
b = min(-0.01, b)
150-
a = max( 0.01, a)
151-
price = self.price_function(np.ones_like(x), a, b)
139+
b = np.minimum(np.array([-0.01]), b)
140+
a = np.maximum(np.array([0.01]), a)
141+
price = self.price_function(np.ones_like(a), a, b)
152142

153143
rev = price * self.g.g(a + price * b)
154144
if rev > max_val:
@@ -158,9 +148,9 @@ def max_rev(self, samples, x):
158148

159149
def update_task(self, env):
160150
self.environment_info = env.mdp_info
161-
self.d = self.environment_info.observation_space['features'].shape[0]
162-
self.X = np.empty((0, 2 * self.d))
163-
self.Y = np.empty((0, 1))
151+
self.d = self.environment_info.observation_space['features'].shape[0] * 2
152+
self.M_inv = np.eye(self.d) / self.lam if self.lam != 0 else np.eye(self.d)
153+
self.q = np.zeros(self.d)
164154
self.actionprocessors[-1] = ClipAction(self.environment_info.action_space.low, self.environment_info.action_space.high)
165155
self.t = 0
166156

0 commit comments

Comments
 (0)