
Commit 98eb135

Fixes to most agents; added an option for a log demand link in hyper training.

1 parent 5000b9b commit 98eb135

File tree

15 files changed: +171 −77 lines


ddopai/_modidx.py

Lines changed: 26 additions & 24 deletions
Large diff not rendered.

ddopai/agents/dynamic_pricing/ILQX.py

Lines changed: 5 additions & 6 deletions
@@ -86,11 +86,10 @@ def parameter_update(self):
         self.alpha = results.params[:self.environment_info.observation_space['features'].shape[0]]
         self.beta = results.params[self.environment_info.observation_space['features'].shape[0]:]
 
-    def update_env(self, env):
+    def update_task(self, env):
         self.environment_info = env.mdp_info
         self.X = np.empty((0, self.environment_info.observation_space['features'].shape[0] * 2))
         self.Y = np.empty((0, 1))
-        self.actionprocessors[-1] = ClipAction(self.environment_info.action_space.low, self.environment_info.action_space.high)
         self.M = [[np.power(x,2)+i for x in range(0, int(np.sqrt(self.environment_info.horizon)))] for i in range(0, 2)]
         self.t = 0
 
@@ -127,8 +126,8 @@ def fit(self, dataset, **kwargs):
         action = dataset[0][1]
         self.policy.fit(X, Y, action)
 
-    def update_env(self, env):
-        self.policy.update_env(env)
+    def update_task(self, env):
+        self.policy.update_task(env)
 
 
 # %% ../../../nbs/30_agents/42_DP_agents/11_ILQX_agent.ipynb 6
@@ -157,6 +156,6 @@ def __init__(self,
                          price_function = price_function,
                          g = g)
         super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)
-    def update_env(self, env: object):
+    def update_task(self, env: object):
         """ Update the environment specific parameters of the agent """
-        self.agent.update_env(env)
+        self.agent.update_task(env)
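
The rename from update_env to update_task across the pricing agents matches the meta-learning framing of the experiments: each product/epoch is a new task, and the agent's task-specific state (X, Y, M, t) is reset when it starts. A minimal driver sketch of the implied calling convention, assuming a Gymnasium-style reset/step API (illustrative only, not code from this commit):

# Hypothetical task loop; `envs` would be one pricing env per product.
def run_tasks(agent, envs):
    for env in envs:
        agent.update_task(env)          # reset task-specific state
        obs, info = env.reset()
        done = False
        while not done:
            action = agent.draw_action(obs)
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated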

ddopai/agents/dynamic_pricing/MTS.py

Lines changed: 7 additions & 7 deletions
@@ -147,7 +147,7 @@ def draw_action(self, observation: np.ndarray) -> np.ndarray:
         price = self.price_function(x_feat, alpha, beta)
         for proc in self.actionprocessors:
             price = proc(price)
-        return np.asarray([price], dtype=float)  # keep shape (1,)
+        return price  # keep shape (1,)
 
     # --------------------------------------------------
     # Online update after receiving (x, price, demand)
@@ -157,7 +157,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray, action: float):
         self.t += 1
         m = np.concatenate([X, X * action]).astype(float)  # (2d,)
         self.X_buf = np.vstack([self.X_buf, m])
-        self.Y_buf = np.vstack([self.Y_buf, [[Y]]])
+        self.Y_buf = np.vstack([self.Y_buf, [Y]])
 
         # update posterior *after* burn‑in
         if self.t >= self.t_e:
@@ -166,7 +166,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray, action: float):
     # --------------------------------------------------
     # End of epoch – build OLS & possibly refresh meta‑prior
     # --------------------------------------------------
-    def update_env(self, env):
+    def update_task(self, env):
         """Call this after each product/epoch ends."""
         # ---------- compute OLS (full‑rank guaranteed by burn‑in) ----------
         V = self.X_buf.T @ self.X_buf  # (2d,2d)
@@ -254,8 +254,8 @@ def fit(self, dataset, **kwargs):
         Y = kwargs["demand"][0]
         action = dataset[0][1]
         self.policy.fit(X, Y, action)
-    def update_env(self, env):
-        self.policy.update_env(env)
+    def update_task(self, env):
+        self.policy.update_task(env)
 
 
 # %% ../../../nbs/30_agents/42_DP_agents/12_MTS_agent.ipynb 7
@@ -284,7 +284,7 @@ def __init__(self,
                          agent_name=agent_name, ex_prices=ex_prices, price_function=price_function, g=g)
         super().__init__(environment_info=environment_info, obsprocessors=obsprocessors, agent_name=agent_name)
 
-    def update_env(self, env: object):
+    def update_task(self, env: object):
         """ Update the environment specific parameters of the agent """
-        self.agent.update_env(env)
+        self.agent.update_task(env)
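
Two behavioral fixes sit alongside the rename in MTS.py. First, draw_action now returns the processed price directly: as the retained comment notes, the processed price already has shape (1,), so re-wrapping it with np.asarray([price]) would have produced a nested (1, 1) action. Second, the Y-buffer update uses [Y] rather than [[Y]]. A quick shape check of that change (illustrative only):

import numpy as np

Y_buf = np.empty((0, 1))
for Y in (0.7, np.array([1.3])):        # demand as scalar or 1-element array
    Y_buf = np.vstack([Y_buf, [Y]])     # each append adds one (1, 1) row
print(Y_buf.shape)                      # (2, 1)

With a scalar Y both forms work, but when Y arrives as a 1-element array, [[Y]] builds a (1, 1, 1) block and np.vstack raises a dimension mismatch, while [Y] still stacks cleanly.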

ddopai/agents/dynamic_pricing/TS.py

Lines changed: 5 additions & 5 deletions
@@ -119,7 +119,7 @@ def parameter_update(self):
         self.alpha = results.params[:self.environment_info.observation_space['features'].shape[0]]
         self.beta = results.params[self.environment_info.observation_space['features'].shape[0]:]
 
-    def update_env(self, env):
+    def update_task(self, env):
         self.environment_info = env.mdp_info
         self.X = np.empty((0, self.environment_info.observation_space['features'].shape[0] * 2))
         self.Y = np.empty((0, 1))
@@ -159,8 +159,8 @@ def fit(self, dataset, **kwargs):
         Y = kwargs["demand"][0]
         action = dataset[0][1]
         self.policy.fit(X, Y, action)
-    def update_env(self, env):
-        self.policy.update_env(env)
+    def update_task(self, env):
+        self.policy.update_task(env)
 
 # %% ../../../nbs/30_agents/42_DP_agents/12_TS_agent.ipynb 6
 class TSAgent(PricingMushroomBaseAgent):
@@ -190,6 +190,6 @@ def __init__(self,
                          price_function=price_function,
                          g=g)
         super().__init__(environment_info=environment_info, obsprocessors=obsprocessors, agent_name=agent_name)
-    def update_env(self, env: object):
+    def update_task(self, env: object):
         """ Update the environment specific parameters of the agent """
-        self.agent.update_env(env)
+        self.agent.update_task(env)

ddopai/agents/dynamic_pricing/UCB.py

Lines changed: 5 additions & 5 deletions
@@ -119,7 +119,7 @@ def parameter_update(self):
         self.alpha = results.params[:self.environment_info.observation_space['features'].shape[0]]
         self.beta = results.params[self.environment_info.observation_space['features'].shape[0]:]
 
-    def update_env(self, env):
+    def update_task(self, env):
         self.environment_info = env.mdp_info
         self.X = np.empty((0, self.environment_info.observation_space['features'].shape[0] * 2))
         self.Y = np.empty((0, 1))
@@ -162,8 +162,8 @@ def fit(self, dataset, **kwargs):
         action = dataset[0][1]
         self.policy.fit(X, Y, action)
 
-    def update_env(self, env):
-        self.policy.update_env(env)
+    def update_task(self, env):
+        self.policy.update_task(env)
 
 # %% ../../../nbs/30_agents/42_DP_agents/13_UCB_agent.ipynb 6
 class UCBAgent(PricingMushroomBaseAgent):
@@ -193,6 +193,6 @@ def __init__(self,
                          price_function=price_function,
                          g=g)
         super().__init__(environment_info=environment_info, obsprocessors=obsprocessors, agent_name=agent_name)
-    def update_env(self, env: object):
+    def update_task(self, env: object):
         """ Update the environment specific parameters of the agent """
-        self.agent.update_env(env)
+        self.agent.update_task(env)

ddopai/agents/rl/hyper.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def draw_action_(self, obs: np.ndarray) -> np.ndarray: # DDOP naming
                                        latent_logvar=self.latent_logvar)
 
         with torch.no_grad():
-            _, action, _ = self.policy.act(state=state_t.view(-1),
+            _, action, _ = self.policy.act(state=state_t,
                                            latent=latent,
                                            belief=None, task=None,
                                            deterministic=self.deterministic)
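
The hyper-network policy change removes the .view(-1) flatten, so the state keeps its batch dimension when passed to policy.act. A tiny shape check of what presumably went wrong (assumed shapes, not from the repo):

import torch

state_t = torch.zeros(1, 8)       # (batch=1, obs_dim=8)
print(state_t.view(-1).shape)     # torch.Size([8])   -> batch dim lost
print(state_t.shape)              # torch.Size([1, 8]) -> batched, as a Linear layer expects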

ddopai/agents/rl/sac.py

Lines changed: 48 additions & 1 deletion
@@ -129,7 +129,27 @@ def __init__(self,
             dropout=self.dropout,)
 
         critic_params = merge_dictionaries(critic_params, network_critic_params)
-
+        self.agent_params = {
+            "mdp_info": environment_info,
+            "actor_mu_params": actor_mu_params,
+            "actor_sigma_params": actor_sigma_params,
+            "actor_optimizer": actor_optimizer,
+            "critic_params": critic_params,
+            "batch_size": batch_size,
+            "initial_replay_size": initial_replay_size,
+            "max_replay_size": max_replay_size,
+            "warmup_transitions": warmup_transitions,
+            "tau": tau,
+            "lr_alpha": lr_alpha,
+            "use_log_alpha_loss": use_log_alpha_loss,
+            "log_std_min": log_std_min,
+            "log_std_max": log_std_max,
+            "target_entropy": target_entropy,
+            "critic_fit_params": None
+        }
+        self._obsprocessors = obsprocessors
+        self.device = device
+        self.agent_name = agent_name
         self.agent = SAC(
             mdp_info=environment_info,
             actor_mu_params=actor_mu_params,
@@ -228,6 +248,33 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: #
         action = action.cpu().detach().numpy()
 
         return action
+
+    def update_task(self, env):
+        self.agent = SAC(
+            mdp_info=env.mdp_info,
+            actor_mu_params=self.agent_params["actor_mu_params"],
+            actor_sigma_params=self.agent_params["actor_sigma_params"],
+            actor_optimizer=self.agent_params["actor_optimizer"],
+            critic_params=self.agent_params["critic_params"],
+            batch_size=self.agent_params["batch_size"],
+            initial_replay_size=self.agent_params["initial_replay_size"],
+            max_replay_size=self.agent_params["max_replay_size"],
+            warmup_transitions=self.agent_params["warmup_transitions"],
+            tau=self.agent_params["tau"],
+            lr_alpha=self.agent_params["lr_alpha"],
+            use_log_alpha_loss=self.agent_params["use_log_alpha_loss"],
+            log_std_min=self.agent_params["log_std_min"],
+            log_std_max=self.agent_params["log_std_max"],
+            target_entropy=self.agent_params["target_entropy"],
+            critic_fit_params=self.agent_params["critic_fit_params"]
+        )
+        self.obsprocessors = self._obsprocessors
+        super().__init__(
+            environment_info=env.mdp_info,
+            obsprocessors=self._obsprocessors,
+            device=self.device,
+            agent_name=self.agent_name
+        )
 
 # %% ../../../nbs/30_agents/51_RL_agents/10_SAC_agents.ipynb 6
 class SACAgent(SACBaseAgent):
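
SACBaseAgent now stores its constructor arguments in self.agent_params so that update_task can rebuild the inner mushroom_rl SAC against the new task's mdp_info (and hence its observation and action spaces). Note that, as written, update_task constructs a fresh learner, so network weights and the replay buffer start over on each task. A hedged usage sketch (make_env and train_on are hypothetical placeholders, not repo functions):

# Hypothetical multi-task loop around SACAgent.update_task.
for params in task_parameters:       # placeholder task specs
    env = make_env(params)           # placeholder env factory
    agent.update_task(env)           # fresh SAC sized to env.mdp_info
    train_on(agent, env)             # placeholder training routine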

ddopai/experiments/meta_experiment_functions.py

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def create_online_data(
     size = parameter["horizon"]
     noise_std = parameter["noise_std"]
     if nb_features > 1:
-        scale = 1 / np.sqrt(nb_features-1)
+        scale = 1 / np.sqrt(nb_features)
         X = np.random.uniform(0, scale, size=(size, nb_features))
     else:
         X = np.ones((size, 1))
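
The feature-scale fix changes the per-coordinate range from U(0, 1/sqrt(nb_features-1)) to U(0, 1/sqrt(nb_features)). Plausibly the point (the commit does not say) is to bound the feature-vector norm by 1 for any dimensionality: with d coordinates each at most 1/sqrt(d), ||x||^2 <= d * (1/d) = 1, whereas the old scale allowed norms up to sqrt(d/(d-1)) > 1. A quick numerical check:

import numpy as np

for d in (2, 5, 50):
    X = np.random.uniform(0, 1 / np.sqrt(d), size=(1000, d))
    assert np.linalg.norm(X, axis=1).max() <= 1.0   # holds by construction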

nbs/30_agents/42_DP_agents/11_ILQX_agent.ipynb

Lines changed: 5 additions & 6 deletions
@@ -123,11 +123,10 @@
     "        self.alpha = results.params[:self.environment_info.observation_space['features'].shape[0]]\n",
     "        self.beta = results.params[self.environment_info.observation_space['features'].shape[0]:]\n",
     "    \n",
-    "    def update_env(self, env):\n",
+    "    def update_task(self, env):\n",
     "        self.environment_info = env.mdp_info\n",
     "        self.X = np.empty((0, self.environment_info.observation_space['features'].shape[0] * 2))\n",
     "        self.Y = np.empty((0, 1))\n",
-    "        self.actionprocessors[-1] = ClipAction(self.environment_info.action_space.low, self.environment_info.action_space.high)\n",
     "        self.M = [[np.power(x,2)+i for x in range(0, int(np.sqrt(self.environment_info.horizon)))] for i in range(0, 2)]\n",
     "        self.t = 0 \n",
     "    \n",
@@ -171,8 +170,8 @@
     "        action = dataset[0][1]\n",
     "        self.policy.fit(X, Y, action)\n",
     "\n",
-    "    def update_env(self, env):\n",
-    "        self.policy.update_env(env)\n"
+    "    def update_task(self, env):\n",
+    "        self.policy.update_task(env)\n"
    ]
   },
   {
@@ -207,9 +206,9 @@
     "                         price_function = price_function, \n",
     "                         g = g)\n",
     "        super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)\n",
-    "    def update_env(self, env: object):\n",
+    "    def update_task(self, env: object):\n",
     "        \"\"\" Update the environment specific parameters of the agent \"\"\"\n",
-    "        self.agent.update_env(env)"
+    "        self.agent.update_task(env)"
    ]
   }
  ],

nbs/30_agents/42_DP_agents/12_MTS_agent.ipynb

Lines changed: 7 additions & 7 deletions
@@ -189,7 +189,7 @@
     "        price = self.price_function(x_feat, alpha, beta)\n",
     "        for proc in self.actionprocessors:\n",
     "            price = proc(price)\n",
-    "        return np.asarray([price], dtype=float)  # keep shape (1,)\n",
+    "        return price  # keep shape (1,)\n",
     "    \n",
     "    # --------------------------------------------------\n",
     "    # Online update after receiving (x, price, demand)\n",
@@ -199,7 +199,7 @@
     "        self.t += 1\n",
     "        m = np.concatenate([X, X * action]).astype(float)  # (2d,)\n",
     "        self.X_buf = np.vstack([self.X_buf, m])\n",
-    "        self.Y_buf = np.vstack([self.Y_buf, [[Y]]])\n",
+    "        self.Y_buf = np.vstack([self.Y_buf, [Y]])\n",
     "\n",
     "        # update posterior *after* burn‑in\n",
     "        if self.t >= self.t_e:\n",
@@ -208,7 +208,7 @@
     "    # --------------------------------------------------\n",
     "    # End of epoch – build OLS & possibly refresh meta‑prior\n",
     "    # --------------------------------------------------\n",
-    "    def update_env(self, env):\n",
+    "    def update_task(self, env):\n",
     "        \"\"\"Call this after each product/epoch ends.\"\"\"\n",
     "        # ---------- compute OLS (full‑rank guaranteed by burn‑in) ----------\n",
     "        V = self.X_buf.T @ self.X_buf  # (2d,2d)\n",
@@ -302,8 +302,8 @@
     "        Y = kwargs[\"demand\"][0]\n",
     "        action = dataset[0][1]\n",
     "        self.policy.fit(X, Y, action)\n",
-    "    def update_env(self, env):\n",
-    "        self.policy.update_env(env)\n"
+    "    def update_task(self, env):\n",
+    "        self.policy.update_task(env)\n"
    ]
   },
   {
@@ -338,9 +338,9 @@
     "                         agent_name=agent_name, ex_prices=ex_prices, price_function=price_function, g=g)\n",
     "        super().__init__(environment_info=environment_info, obsprocessors=obsprocessors, agent_name=agent_name)\n",
     "    \n",
-    "    def update_env(self, env: object):\n",
+    "    def update_task(self, env: object):\n",
     "        \"\"\" Update the environment specific parameters of the agent \"\"\"\n",
-    "        self.agent.update_env(env)\n"
+    "        self.agent.update_task(env)\n"
    ]
   }
  ],
