
Commit bf757a7

moved plot outside and copied the get test rollout code which worked
1 parent 3182936 commit bf757a7

3 files changed: +150 -166 lines changed


ddopai/_modidx.py

Lines changed: 3 additions & 1 deletion
@@ -1745,7 +1745,9 @@
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.step': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.step',
         'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.visualise_behaviour',
-        'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
+        'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
+    'ddopai.meta_learning.environments.pricing_env.pricing_env.plot_actions_reward': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#plot_actions_reward',
+        'ddopai/meta_learning/environments/pricing_env/pricing_env.py')},
     'ddopai.meta_learning.environments.wrappers': { 'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper',
         'ddopai/meta_learning/environments/wrappers.py'),
     'ddopai.meta_learning.environments.wrappers.PrevActRewWrapper.__init__': ( '50_meta_learning/53_environments/wrappers.html#prevactrewwrapper.__init__',

ddopai/meta_learning/environments/pricing_env/pricing_env.py

Lines changed: 71 additions & 83 deletions
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb.
 
 # %% auto 0
-__all__ = ['device', 'PricingEnv']
+__all__ = ['device', 'PricingEnv', 'plot_actions_reward']
 
 # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 1
 import gym
@@ -261,153 +261,141 @@ def visualise_behaviour(env,
     """
 
     num_episodes = args.max_rollouts_per_task
-    unwrapped_env = env.venv.unwrapped.envs[0]
 
-    episode_all_obs = [[] for _ in range(num_episodes)]
+    # --- initialise things we want to keep track of ---
+
     episode_prev_obs = [[] for _ in range(num_episodes)]
     episode_next_obs = [[] for _ in range(num_episodes)]
     episode_actions = [[] for _ in range(num_episodes)]
     episode_rewards = [[] for _ in range(num_episodes)]
+
     episode_returns = []
     episode_lengths = []
 
-    if args.pass_belief_to_policy and (encoder is None):
-        episode_beliefs = [[] for _ in range(num_episodes)]
-    else:
-        episode_beliefs = None
-
     if encoder is not None:
         episode_latent_samples = [[] for _ in range(num_episodes)]
         episode_latent_means = [[] for _ in range(num_episodes)]
         episode_latent_logvars = [[] for _ in range(num_episodes)]
     else:
-        episode_latent_samples = episode_latent_means = episode_latent_logvars = None
+        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
+        episode_latent_means = episode_latent_logvars = None
+
+    # --- roll out policy ---
 
+    # (re)set environment
     env.reset_task()
-    [state, belief, task] = utl.reset_env(env, args)
-    start_obs = state.clone()
+    state, belief, task = utl.reset_env(env, args)
+    state = state.reshape((1, -1)).to(device)
+    task = task.view(-1) if task is not None else None
 
     for episode_idx in range(num_episodes):
 
         curr_rollout_rew = []
 
-
-
         if encoder is not None:
-
-            if episode_idx == 0 and encoder is not None:
+            if episode_idx == 0:
                 # reset to prior
                 curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                 curr_latent_sample = curr_latent_sample[0].to(device)
                 curr_latent_mean = curr_latent_mean[0].to(device)
                 curr_latent_logvar = curr_latent_logvar[0].to(device)
-
             episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
             episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
             episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
 
-        episode_all_obs[episode_idx].append(start_obs.clone())
-        if args.pass_belief_to_policy and (encoder is None):
-            episode_beliefs[episode_idx].append(belief)
-
-
-        for step_idx in range(env._max_episode_steps):
-
-            if step_idx == 1:
-                prev_obs = start_obs.clone()
-            else:
-                prev_obs = state.clone()
-
-            episode_prev_obs[episode_idx].append(prev_obs)
-
-            # act
-            _, action, _ = utl.select_action(args=args,
-                                             policy=policy,
-                                             state=state.view(-1),
-                                             belief=belief,
-                                             task=task,
-                                             deterministic=True,
-                                             latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None,
-                                             latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None,
-                                             latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None,
-                                             )
-
+        for step_idx in range(1, env._max_episode_steps + 1):
+
+            episode_prev_obs[episode_idx].append(state.clone())
+            prev_state = state.clone()
+
+            latent = utl.get_latent_for_policy(args,
+                                               latent_sample=curr_latent_sample,
+                                               latent_mean=curr_latent_mean,
+                                               latent_logvar=curr_latent_logvar)
+            _, action, _ = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True)
+            action = action.reshape((1, *action.shape))
+
             # observe reward and next obs
-            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)
+            (state, belief, task), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)
+            state = state.reshape((1, -1)).to(device)
+            task = task.view(-1) if task is not None else None
 
             if encoder is not None:
                 # update task embedding
                 curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                     action.float().to(device),
                     state,
                     rew_raw.reshape((1, 1)).float().to(device),
-                    prev_obs,
+                    prev_state,
                     hidden_state,
                     return_prior=False)
 
                 episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                 episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                 episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
 
-            episode_all_obs[episode_idx].append(state.clone())
             episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
             episode_actions[episode_idx].append(action.clone())
 
-            curr_rollout_rew.append(rew_raw.clone())
-
-
-            if args.pass_belief_to_policy and (encoder is None):
-                episode_beliefs[episode_idx].append(belief)
-
-            if infos[0]['done_mdp'] and not done:
-                start_obs = infos[0]['start_state']
-                start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device)
+            if infos[0]['done_mdp']:
                 break
 
         episode_returns.append(sum(curr_rollout_rew))
         episode_lengths.append(step_idx)
 
-
     # clean up
-
     if encoder is not None:
         episode_latent_means = [torch.stack(e) for e in episode_latent_means]
         episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]
 
     episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
     episode_next_obs = [torch.cat(e) for e in episode_next_obs]
     episode_actions = [torch.cat(e) for e in episode_actions]
-    episode_rewards = [torch.cat(e) for e in episode_rewards]
-
-
-    # Plot price and reward trajectories
-    import matplotlib.pyplot as plt
-    plt.figure(figsize=(10, 3 * num_episodes))
-    for i in range(num_episodes):
-        plt.subplot(num_episodes, 2, 2 * i + 1)
-        plt.plot(episode_actions[i].cpu().numpy(), label="Price")
-        plt.ylabel("Price")
-        plt.xlabel("Timestep")
-        plt.title(f"Episode {i}: Price")
-
-        plt.subplot(num_episodes, 2, 2 * i + 2)
-        plt.plot(episode_rewards[i].cpu().numpy(), label="Revenue", color='green')
-        plt.ylabel("Revenue")
-        plt.xlabel("Timestep")
-        plt.title(f"Episode {i}: Revenue")
-
-    plt.tight_layout()
-    if image_folder is not None:
-        plt.savefig(f"{image_folder}/{iter_idx}_pricing_behaviour.png")
-        plt.close()
-    else:
-        plt.show()
-
+    episode_rewards = [torch.cat(r) for r in episode_rewards]
+
+    plot_actions_reward(
+        episode_actions=episode_actions,
+        episode_rewards=episode_rewards,
+        episode_lengths=episode_lengths,
+        image_folder=image_folder,
+        iter_idx=iter_idx
+    )
     return episode_latent_means, episode_latent_logvars, \
-           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
-           episode_returns
+        episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
+        episode_returns
+
+
+
 
 
 
 
 
+
+# %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 3
+def plot_actions_reward(
+    episode_actions: List[torch.Tensor],
+    episode_rewards: List[torch.Tensor],
+    episode_lengths: List[int],
+    image_folder: Optional[str] = None,
+    iter_idx: int = 0
+):
+    """
+    Plot actions and rewards for each episode.
+    """
+    plt.figure(figsize=(12, 6))
+    for i, (actions, rewards, length) in enumerate(zip(episode_actions, episode_rewards, episode_lengths)):
+        plt.plot(range(length), actions.cpu().numpy(), label=f'Episode {i+1} Actions')
+        plt.plot(range(length), rewards.cpu().numpy(), label=f'Episode {i+1} Rewards', linestyle='--')
+
+    plt.xlabel('Time Step')
+    plt.ylabel('Value')
+    plt.title('Actions and Rewards per Episode')
+    plt.legend()
+
+    if image_folder:
+        plt.savefig(f"{image_folder}/actions_rewards_iter_{iter_idx}.png")
+    else:
+        plt.show()
+

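Note: a minimal, hypothetical usage sketch of the new module-level helper introduced by this commit. It assumes the import path shown in the _modidx.py diff above and feeds in synthetic per-episode tensors in place of the rollout data that visualise_behaviour would normally collect; with image_folder=None the helper shows the figure, while a directory path would make it save actions_rewards_iter_0.png there instead.

import torch
from ddopai.meta_learning.environments.pricing_env.pricing_env import plot_actions_reward

# Hypothetical stand-in data: two episodes of 20 steps each.
num_episodes, horizon = 2, 20
episode_actions = [torch.rand(horizon) for _ in range(num_episodes)]  # e.g. prices per step
episode_rewards = [torch.rand(horizon) for _ in range(num_episodes)]  # e.g. revenue per step
episode_lengths = [horizon] * num_episodes

# Plots one action curve and one (dashed) reward curve per episode on shared axes.
plot_actions_reward(
    episode_actions=episode_actions,
    episode_rewards=episode_rewards,
    episode_lengths=episode_lengths,
    image_folder=None,
    iter_idx=0,
)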