Commit 3406fff

added vis to the pricing env

Parent: adfe25d

5 files changed: +453 −1075 lines

ddopai/_modidx.py

Lines changed: 2 additions & 2 deletions

@@ -1744,8 +1744,8 @@
                       'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
   'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.step': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.step',
                       'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
-  'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.visualise_behaviour',
-                      'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
+  'ddopai.meta_learning.environments.pricing_env.pricing_env.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#visualise_behaviour',
+                      'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
   'ddopai.meta_learning.environments.wrappers': { 'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper',
                       'ddopai/meta_learning/environments/wrappers.py'),
   'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper.__init__': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper.__init__',
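For readers unfamiliar with nbdev: _modidx.py is the autogenerated index that maps every exported symbol to its documentation anchor and source file. The re-keyed entry above drops the PricingEnv. prefix, i.e. the index now records visualise_behaviour as a module-level name, matching the __all__ change in the next file. Illustrative shape of one entry (not part of the commit):

entry = {
    # symbol -> (docs anchor, source file)
    'ddopai.meta_learning.environments.pricing_env.pricing_env.visualise_behaviour': (
        '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#visualise_behaviour',
        'ddopai/meta_learning/environments/pricing_env/pricing_env.py',
    ),
}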

ddopai/meta_learning/environments/pricing_env/pricing_env.py

Lines changed: 136 additions & 7 deletions

@@ -1,14 +1,16 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb.
 
 # %% auto 0
-__all__ = ['PricingEnv']
+__all__ = ['PricingEnv', 'visualise_behaviour']
 
 # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 1
 import gym
 from abc import ABC, abstractmethod
 from typing import Union, List, Dict, Optional
 import numpy as np
-
+import matplotlib.pyplot as plt
+import torch
+from ...utils import helpers as utl
 
 # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 2
 class PricingEnv(gym.Env):
@@ -187,6 +189,7 @@ def step(self, action):
         obs = self._get_obs()
         info = {
             "task": self.get_task(),
+            "action": price,
             "noise": noise,
             "demand": demand,
             "sales": sales,
@@ -242,9 +245,135 @@ def _demand(self, price: float, noise: float) -> float:
         return max(0.0, mean + noise)
 
     # ---------- visualisation stub -------------------------------------------
-    def visualise_behaviour(self, *_, **__):
-        """
-        Optional — leave blank. Hyper’s default visualiser is used if None.
-        """
-        return None, None, None, None, None, None, None
+    @staticmethod
+    def visualise_behaviour(env,
+                            args,
+                            policy,
+                            iter_idx,
+                            encoder=None,
+                            image_folder=None,
+                            return_pos=False,
+                            **kwargs):
+
+        num_episodes = args.max_rollouts_per_task
+
+        episode_prev_obs = [[] for _ in range(num_episodes)]
+        episode_next_obs = [[] for _ in range(num_episodes)]
+        episode_actions = [[] for _ in range(num_episodes)]  # price = action
+        episode_rewards = [[] for _ in range(num_episodes)]
+        episode_returns = []
+
+        if encoder is not None:
+            episode_latent_samples = [[] for _ in range(num_episodes)]
+            episode_latent_means = [[] for _ in range(num_episodes)]
+            episode_latent_logvars = [[] for _ in range(num_episodes)]
+        else:
+            episode_latent_samples = episode_latent_means = episode_latent_logvars = None
+
+        env.reset_task()
+        state, belief, task = utl.reset_env(env, args)
+        task = task.view(-1) if task is not None else None
+
+        hidden_state = torch.zeros((1, args.hidden_size)).to(args.device) if hasattr(args, 'hidden_size') else None
+
+        for episode_idx in range(num_episodes):
+            curr_rollout_rew = []
+
+            if episode_idx == 0:
+                if encoder is not None:
+                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
+                    curr_latent_sample = curr_latent_sample[0].to(args.device)
+                    curr_latent_mean = curr_latent_mean[0].to(args.device)
+                    curr_latent_logvar = curr_latent_logvar[0].to(args.device)
+                else:
+                    curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
+
+            if encoder is not None:
+                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+
+            obs = env.reset()
+
+            for step_idx in range(1, env.horizon + 1):
+                prev_obs = torch.tensor(obs, dtype=torch.float32).to(args.device).unsqueeze(0)
+                episode_prev_obs[episode_idx].append(prev_obs.clone())
+
+                latent = utl.get_latent_for_policy(args,
+                                                   latent_sample=curr_latent_sample,
+                                                   latent_mean=curr_latent_mean,
+                                                   latent_logvar=curr_latent_logvar)
+
+                _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)
+
+                obs, reward, done, info = env.step(action.cpu().numpy())
+                obs = torch.tensor(obs, dtype=torch.float32).to(args.device).unsqueeze(0)
+
+                episode_next_obs[episode_idx].append(obs.clone())
+                episode_actions[episode_idx].append(action.clone())
+                episode_rewards[episode_idx].append(torch.tensor([reward], dtype=torch.float32).to(args.device))
+                curr_rollout_rew.append(reward)
+
+                if encoder is not None:
+                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
+                        action.reshape(1, -1).float().to(args.device),
+                        obs,
+                        torch.tensor([reward], dtype=torch.float32, device=args.device).reshape(1, -1),
+                        prev_obs,
+                        hidden_state,
+                        return_prior=False,
+                    )
+                    episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+                    episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+                    episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+
+                if done:
+                    break
+
+            episode_returns.append(sum(curr_rollout_rew))
+
+        # Convert to tensor batches
+        if encoder is not None:
+            episode_latent_means = [torch.stack(e) for e in episode_latent_means]
+            episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
+
+        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
+        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
+        episode_actions = [torch.stack(e) for e in episode_actions]
+        episode_rewards = [torch.cat(e) for e in episode_rewards]
+
+        # ---- Plot: Price (action) and Revenue ----
+        import matplotlib.pyplot as plt
+
+        plt.figure(figsize=(10, 3 * num_episodes))
+        for i in range(num_episodes):
+            plt.subplot(num_episodes, 2, 2 * i + 1)
+            plt.plot(episode_actions[i].cpu().numpy(), label="Price")
+            plt.ylabel("Price")
+            plt.xlabel("Timestep")
+            plt.title(f"Episode {i}: Price")
+
+            plt.subplot(num_episodes, 2, 2 * i + 2)
+            plt.plot(episode_rewards[i].cpu().numpy(), label="Revenue", color='green')
+            plt.ylabel("Revenue")
+            plt.xlabel("Timestep")
+            plt.title(f"Episode {i}: Revenue")
+
+        plt.tight_layout()
+        if image_folder is not None:
+            plt.savefig(f"{image_folder}/{iter_idx}_pricing_behaviour.png")
+            plt.close()
+        else:
+            plt.show()
+
+        if not return_pos:
+            return episode_latent_means, episode_latent_logvars, \
+                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
+                   episode_returns
+        else:
+            return episode_latent_means, episode_latent_logvars, \
+                   episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
+                   episode_returns, episode_actions  # actions = price = pos
+
+
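To preview the figure layout the new visualiser produces without a trained policy, here is a self-contained sketch (not part of the commit) that reproduces the subplot grid with synthetic data: one row per episode, the price trajectory on the left and the per-step revenue on the right. The toy revenue curve p * (10 - p) is an assumption chosen only to make the plot non-trivial.

import matplotlib.pyplot as plt
import torch

num_episodes = 2
horizon = 50

# Synthetic stand-ins for episode_actions / episode_rewards.
episode_actions = [torch.rand(horizon, 1) * 10 for _ in range(num_episodes)]  # prices in [0, 10]
episode_rewards = [p * (10 - p) for p in episode_actions]                     # toy revenue

plt.figure(figsize=(10, 3 * num_episodes))
for i in range(num_episodes):
    plt.subplot(num_episodes, 2, 2 * i + 1)   # left column: price
    plt.plot(episode_actions[i].numpy(), label="Price")
    plt.ylabel("Price")
    plt.xlabel("Timestep")
    plt.title(f"Episode {i}: Price")

    plt.subplot(num_episodes, 2, 2 * i + 2)   # right column: revenue
    plt.plot(episode_rewards[i].numpy(), label="Revenue", color="green")
    plt.ylabel("Revenue")
    plt.xlabel("Timestep")
    plt.title(f"Episode {i}: Revenue")

plt.tight_layout()
plt.show()

The method itself additionally expects, as the diff shows, a policy whose act(...) returns a triple with the action in second position, and an optional encoder exposing prior(batch_size).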

ddopai/meta_learning/utils/helpers.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from ..environments.pricing_env.pricing_env import PricingEnv
+
 from ..environments.wrappers import PrevActRewWrapper
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -39,7 +39,7 @@ def make_env(args, mode='train', **kwargs):
     """
     assert args.env_name.lower().startswith('pricing'), \
         "This trimmed helper only supports PricingEnv."
-
+    from ddopai.meta_learning.environments.pricing_env.pricing_env import PricingEnv
     # base env --------------------------------------------------------
     env = PricingEnv(**args.pricing_kwargs)
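A plausible reason the import moved (not stated in the commit): pricing_env.py now imports the helpers module at top level (from ...utils import helpers as utl), so a module-level import of PricingEnv in helpers.py would complete a circular import. Deferring the import into make_env resolves it, because the function body runs only after both modules have finished initialising. A minimal sketch of the pattern, using hypothetical modules env_demo.py and helpers_demo.py:

# env_demo.py (stands in for pricing_env.py)
from helpers_demo import some_util    # needs helpers at import time

class DemoEnv:
    def __init__(self):
        self.x = some_util()

# helpers_demo.py (stands in for helpers.py)
def some_util():
    return 42

# from env_demo import DemoEnv       # module level: re-enters env_demo while
                                     # it is still importing -> ImportError

def make_env():
    from env_demo import DemoEnv     # deferred: both modules are loaded by now
    return DemoEnv()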

nbs/50_meta_learning/50_utils/20_helpers.ipynb

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
     "import torch\n",
     "import torch.nn as nn\n",
     "from torch.nn import functional as F\n",
-    "from ddopai.meta_learning.environments.pricing_env.pricing_env import PricingEnv\n",
+    "\n",
     "from ddopai.meta_learning.environments.wrappers import PrevActRewWrapper\n",
     "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
 ]
@@ -57,7 +57,7 @@
     "    \"\"\"\n",
     "    assert args.env_name.lower().startswith('pricing'), \\\n",
     "        \"This trimmed helper only supports PricingEnv.\"\n",
-    "\n",
+    "    from ddopai.meta_learning.environments.pricing_env.pricing_env import PricingEnv\n",
     "    # base env --------------------------------------------------------\n",
     "    env = PricingEnv(**args.pricing_kwargs)\n",
     "\n",
