
Commit 9991dc2

Fixed horizon call in visualise env
1 parent 3406fff commit 9991dc2

File tree

3 files changed: +224 -234 lines changed

ddopai/_modidx.py (2 additions, 2 deletions)

@@ -1744,8 +1744,8 @@
                     'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.step': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.step',
                     'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
-    'ddopai.meta_learning.environments.pricing_env.pricing_env.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#visualise_behaviour',
-                    'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
+    'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.visualise_behaviour',
+                    'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
     'ddopai.meta_learning.environments.wrappers': { 'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper',
                     'ddopai/meta_learning/environments/wrappers.py'),
     'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper.__init__': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper.__init__',
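For orientation: _modidx.py is the symbol index that nbdev autogenerates, mapping each exported name to its documentation anchor and source file. A trimmed sketch of the updated entry's shape (only the one renamed key is shown; the enclosing dictionary follows nbdev's d = {'syms': ...} layout):

# Sketch of the nbdev symbol-index entry after this commit: both the key and
# the HTML anchor gain the class prefix, because visualise_behaviour now
# lives on PricingEnv rather than at module level.
d = { 'syms': {
        'ddopai.meta_learning.environments.pricing_env.pricing_env': {
            'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.visualise_behaviour': (
                '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.visualise_behaviour',
                'ddopai/meta_learning/environments/pricing_env/pricing_env.py')}}}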

ddopai/meta_learning/environments/pricing_env/pricing_env.py (114 additions, 119 deletions)
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb.
 
 # %% auto 0
-__all__ = ['PricingEnv', 'visualise_behaviour']
+__all__ = ['device', 'PricingEnv']
 
 # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 1
 import gym
@@ -11,6 +11,7 @@
 import matplotlib.pyplot as plt
 import torch
 from ...utils import helpers as utl
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 2
 class PricingEnv(gym.Env):
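The device line added above is PyTorch's common select-once-at-import pattern; the remaining hunk below then switches visualise_behaviour from args.device to this module-level device. A minimal, runnable sketch of what the pattern does (the reward tensor here is purely illustrative, not from the diff):

import torch

# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tensors built during a rollout can then be placed on that device, e.g. the
# per-step reward tensor that visualise_behaviour constructs.
reward_t = torch.tensor([1.0], dtype=torch.float32).to(device)
print(reward_t.device)  # prints cuda:0 or cpu depending on the host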
@@ -245,135 +246,129 @@ def _demand(self, price: float, noise: float) -> float:
         return max(0.0, mean + noise)
 
     # ---------- visualisation stub -------------------------------------------
-@staticmethod
-def visualise_behaviour(env,
-                        args,
-                        policy,
-                        iter_idx,
-                        encoder=None,
-                        image_folder=None,
-                        return_pos=False,
-                        **kwargs):
-
-    num_episodes = args.max_rollouts_per_task
-
-    episode_prev_obs = [[] for _ in range(num_episodes)]
-    episode_next_obs = [[] for _ in range(num_episodes)]
-    episode_actions = [[] for _ in range(num_episodes)]  # price = action
-    episode_rewards = [[] for _ in range(num_episodes)]
-    episode_returns = []
-
-    if encoder is not None:
-        episode_latent_samples = [[] for _ in range(num_episodes)]
-        episode_latent_means = [[] for _ in range(num_episodes)]
-        episode_latent_logvars = [[] for _ in range(num_episodes)]
-    else:
-        episode_latent_samples = episode_latent_means = episode_latent_logvars = None
-
-    env.reset_task()
-    state, belief, task = utl.reset_env(env, args)
-    task = task.view(-1) if task is not None else None
-
-    hidden_state = torch.zeros((1, args.hidden_size)).to(args.device) if hasattr(args, 'hidden_size') else None
-
-    for episode_idx in range(num_episodes):
-        curr_rollout_rew = []
-
-        if episode_idx == 0:
-            if encoder is not None:
-                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
-                curr_latent_sample = curr_latent_sample[0].to(args.device)
-                curr_latent_mean = curr_latent_mean[0].to(args.device)
-                curr_latent_logvar = curr_latent_logvar[0].to(args.device)
-            else:
-                curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
+    @staticmethod
+    def visualise_behaviour(env,
+                            args,
+                            policy,
+                            iter_idx,
+                            encoder=None,
+                            image_folder=None,
+                            **kwargs):
+
+        num_episodes = args.max_rollouts_per_task
+
+        episode_prev_obs = [[] for _ in range(num_episodes)]
+        episode_next_obs = [[] for _ in range(num_episodes)]
+        episode_actions = [[] for _ in range(num_episodes)]  # price = action
+        episode_rewards = [[] for _ in range(num_episodes)]
+        episode_returns = []
 
         if encoder is not None:
-            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
-            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
-            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+            episode_latent_samples = [[] for _ in range(num_episodes)]
+            episode_latent_means = [[] for _ in range(num_episodes)]
+            episode_latent_logvars = [[] for _ in range(num_episodes)]
+        else:
+            episode_latent_samples = episode_latent_means = episode_latent_logvars = None
+
+        env.reset_task()
+        state, belief, task = utl.reset_env(env, args)
+        task = task.view(-1) if task is not None else None
+
+        hidden_state = torch.zeros((1, args.hidden_size)).to(device) if hasattr(args, 'hidden_size') else None
+
+        for episode_idx in range(num_episodes):
+            curr_rollout_rew = []
+
+            if episode_idx == 0:
+                if encoder is not None:
+                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
+                    curr_latent_sample = curr_latent_sample[0].to(device)
+                    curr_latent_mean = curr_latent_mean[0].to(device)
+                    curr_latent_logvar = curr_latent_logvar[0].to(device)
+                else:
+                    curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
 
-        obs = env.reset()
+            if encoder is not None:
+                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
 
-        for step_idx in range(1, env.horizon + 1):
-            prev_obs = torch.tensor(obs, dtype=torch.float32).to(args.device).unsqueeze(0)
-            episode_prev_obs[episode_idx].append(prev_obs.clone())
+            obs = env.reset()
 
-            latent = utl.get_latent_for_policy(args,
-                                               latent_sample=curr_latent_sample,
-                                               latent_mean=curr_latent_mean,
-                                               latent_logvar=curr_latent_logvar)
+            for step_idx in range(1, env._max_episode_steps + 1):
+                prev_obs = torch.tensor(obs, dtype=torch.float32).to(device).unsqueeze(0)
+                episode_prev_obs[episode_idx].append(prev_obs.clone())
 
-            _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)
+                latent = utl.get_latent_for_policy(args,
+                                                   latent_sample=curr_latent_sample,
+                                                   latent_mean=curr_latent_mean,
+                                                   latent_logvar=curr_latent_logvar)
 
-            obs, reward, done, info = env.step(action.cpu().numpy())
-            obs = torch.tensor(obs, dtype=torch.float32).to(args.device).unsqueeze(0)
+                _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)
 
-            episode_next_obs[episode_idx].append(obs.clone())
-            episode_actions[episode_idx].append(action.clone())
-            episode_rewards[episode_idx].append(torch.tensor([reward], dtype=torch.float32).to(args.device))
-            curr_rollout_rew.append(reward)
+                obs, reward, done, info = env.step(action.cpu().numpy())
+                obs = torch.tensor(obs, dtype=torch.float32).to(device).unsqueeze(0)
 
-            if encoder is not None:
-                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
-                    action.reshape(1, -1).float().to(args.device),
-                    obs,
-                    torch.tensor([reward], dtype=torch.float32, device=args.device).reshape(1, -1),
-                    prev_obs,
-                    hidden_state,
-                    return_prior=False,
-                )
-                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
-                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
-                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+                episode_next_obs[episode_idx].append(obs.clone())
+                episode_actions[episode_idx].append(action.clone())
+                episode_rewards[episode_idx].append(torch.tensor([reward], dtype=torch.float32).to(device))
+                curr_rollout_rew.append(reward)
 
-            if done:
-                break
-
-        episode_returns.append(sum(curr_rollout_rew))
-
-    # Convert to tensor batches
-    if encoder is not None:
-        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
-        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
-
-    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
-    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
-    episode_actions = [torch.stack(e) for e in episode_actions]
-    episode_rewards = [torch.cat(e) for e in episode_rewards]
-
-    # ---- Plot: Price (action) and Revenue ----
-    import matplotlib.pyplot as plt
-
-    plt.figure(figsize=(10, 3 * num_episodes))
-    for i in range(num_episodes):
-        plt.subplot(num_episodes, 2, 2 * i + 1)
-        plt.plot(episode_actions[i].cpu().numpy(), label="Price")
-        plt.ylabel("Price")
-        plt.xlabel("Timestep")
-        plt.title(f"Episode {i}: Price")
-
-        plt.subplot(num_episodes, 2, 2 * i + 2)
-        plt.plot(episode_rewards[i].cpu().numpy(), label="Revenue", color='green')
-        plt.ylabel("Revenue")
-        plt.xlabel("Timestep")
-        plt.title(f"Episode {i}: Revenue")
-
-    plt.tight_layout()
-    if image_folder is not None:
-        plt.savefig(f"{image_folder}/{iter_idx}_pricing_behaviour.png")
-        plt.close()
-    else:
-        plt.show()
-
-    if not return_pos:
-        return episode_latent_means, episode_latent_logvars, \
-        episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
-        episode_returns
-    else:
+                if encoder is not None:
+                    curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
+                        action.reshape(1, -1).float().to(device),
+                        obs,
+                        torch.tensor([reward], dtype=torch.float32, device=device).reshape(1, -1),
+                        prev_obs,
+                        hidden_state,
+                        return_prior=False,
+                    )
+                    episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+                    episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+                    episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+
+                if done:
+                    break
+
+            episode_returns.append(sum(curr_rollout_rew))
+
+        # Convert to tensor batches
+        if encoder is not None:
+            episode_latent_means = [torch.stack(e) for e in episode_latent_means]
+            episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
+
+        episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
+        episode_next_obs = [torch.cat(e) for e in episode_next_obs]
+        episode_actions = [torch.stack(e) for e in episode_actions]
+        episode_rewards = [torch.cat(e) for e in episode_rewards]
+
+
+        plt.figure(figsize=(10, 3 * num_episodes))
+        for i in range(num_episodes):
+            plt.subplot(num_episodes, 2, 2 * i + 1)
+            plt.plot(episode_actions[i].cpu().numpy(), label="Price")
+            plt.ylabel("Price")
+            plt.xlabel("Timestep")
+            plt.title(f"Episode {i}: Price")
+
+            plt.subplot(num_episodes, 2, 2 * i + 2)
+            plt.plot(episode_rewards[i].cpu().numpy(), label="Revenue", color='green')
+            plt.ylabel("Revenue")
+            plt.xlabel("Timestep")
+            plt.title(f"Episode {i}: Revenue")
+
+        plt.tight_layout()
+        if image_folder is not None:
+            plt.savefig(f"{image_folder}/{iter_idx}_pricing_behaviour.png")
+            plt.close()
+        else:
+            plt.show()
+
+
         return episode_latent_means, episode_latent_logvars, \
-        episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
-        episode_returns, episode_actions  # actions = price = pos
+            episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
+            episode_returns
+
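The fix named in the commit message is the loop bound above: the rollout now iterates over env._max_episode_steps, the step budget that gym's TimeLimit wrapper exposes, instead of env.horizon. A minimal sketch of the corrected loop shape, using CartPole-v1 and a random policy purely for illustration, and assuming the classic pre-0.26 gym step API that the diff itself uses:

import gym

env = gym.make("CartPole-v1")  # stand-in env; gym.make wraps it in TimeLimit
obs = env.reset()

total_reward = 0.0
# Iterate up to the TimeLimit step budget, exactly as the fixed loop does.
for step_idx in range(1, env._max_episode_steps + 1):
    obs, reward, done, info = env.step(env.action_space.sample())
    total_reward += reward
    if done:  # the episode may terminate before the budget is exhausted
        break
print(total_reward)

The commit also drops the return_pos parameter and its extra branch, so visualise_behaviour now has a single return path.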
