
Commit d3db406

next try
1 parent 57350b5 commit d3db406

File tree

2 files changed: +164 -96 lines changed


ddopai/meta_learning/environments/pricing_env/pricing_env.py

Lines changed: 82 additions & 48 deletions
@@ -259,21 +259,23 @@ def visualise_behaviour(env,
     Visualise behaviour in PricingEnv: plots price (action) and revenue (reward) per timestep.
     The environment passed to this method should be a vectorised env (DummyVecEnv or SubprocVecEnv).
     """
-    import matplotlib.pyplot as plt
-    import torch
-    import numpy as np
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     num_episodes = args.max_rollouts_per_task
     unwrapped_env = env.venv.unwrapped.envs[0]

+    episode_all_obs = [[] for _ in range(num_episodes)]
     episode_prev_obs = [[] for _ in range(num_episodes)]
     episode_next_obs = [[] for _ in range(num_episodes)]
     episode_actions = [[] for _ in range(num_episodes)]
     episode_rewards = [[] for _ in range(num_episodes)]
     episode_returns = []
+    episode_lengths = []

+    if args.pass_belief_to_policy and (encoder is None):
+        episode_beliefs = [[] for _ in range(num_episodes)]
+    else:
+        episode_beliefs = None
+
     if encoder is not None:
         episode_latent_samples = [[] for _ in range(num_episodes)]
         episode_latent_means = [[] for _ in range(num_episodes)]
@@ -282,70 +284,102 @@ def visualise_behaviour(env,
         episode_latent_samples = episode_latent_means = episode_latent_logvars = None

     env.reset_task()
-    state, belief, task = utl.reset_env(env, args)
-    start_obs_raw = state.clone()
-    task = task.view(-1) if task is not None else None
+    [state, belief, task] = utl.reset_env(env, args)
+    start_obs = state.clone()

-    hidden_state = torch.zeros((1, args.hidden_size), device=device) if hasattr(args, 'hidden_size') else None
+    for episode_idx in range(num_episodes):

-    for ep_idx in range(num_episodes):
-        obs = env.reset()
         curr_rollout_rew = []

-        if ep_idx == 0 and encoder is not None:
-            curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
-            curr_latent_sample = curr_latent_sample[0].to(device)
-            curr_latent_mean = curr_latent_mean[0].to(device)
-            curr_latent_logvar = curr_latent_logvar[0].to(device)
+

         if encoder is not None:
-            episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())
-            episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())
-            episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())
-
-        for t in range(env._max_episode_steps):
-            prev_obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
-            episode_prev_obs[ep_idx].append(prev_obs.squeeze(0).clone())
-
-            latent = utl.get_latent_for_policy(args, curr_latent_sample, curr_latent_mean, curr_latent_logvar)
-            _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)
-
-            obs, reward, done, info = env.step(action.cpu().numpy())
-            obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
-
-            episode_next_obs[ep_idx].append(obs.squeeze(0).clone())
-            episode_actions[ep_idx].append(action.squeeze(0).clone())
-            episode_rewards[ep_idx].append(torch.tensor(reward, dtype=torch.float32, device=device).clone())
-            curr_rollout_rew.append(reward)
+
+            if ep_idx == 0 and encoder is not None:
+                # reset to prior
+                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
+                curr_latent_sample = curr_latent_sample[0].to(device)
+                curr_latent_mean = curr_latent_mean[0].to(device)
+                curr_latent_logvar = curr_latent_logvar[0].to(device)
+
+            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+
+        episode_all_obs[episode_idx].append(start_obs.clone())
+        if args.pass_belief_to_policy and (encoder is None):
+            episode_beliefs[episode_idx].append(belief)
+
+        for step_idx in range(env._max_episode_steps):
+
+            if step_idx == 1:
+                prev_obs = start_obs.clone()
+            else:
+                prev_obs = state.clone()
+
+            episode_prev_obs[episode_idx].append(prev_obs)
+
+            # act
+            _, action, _ = utl.select_action(args=args,
+                                             policy=policy,
+                                             state=state.view(-1),
+                                             belief=belief,
+                                             task=task,
+                                             deterministic=True,
+                                             latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None,
+                                             latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None,
+                                             latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None,
+                                             )
+
+            # observe reward and next obs
+            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)

             if encoder is not None:
+                # update task embedding
                 curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
-                    action.reshape(1, -1).float().to(device),
-                    obs,
-                    torch.tensor([reward], dtype=torch.float32, device=device).unsqueeze(0),
+                    action.float().to(device),
+                    state,
+                    rew_raw.reshape((1, 1)).float().to(device),
                     prev_obs,
                     hidden_state,
-                    return_prior=False,
-                )
-                episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())
-                episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())
-                episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())
+                    return_prior=False)
+
+                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
+                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
+                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())
+
+            episode_all_obs[episode_idx].append(state.clone())
+            episode_next_obs[episode_idx].append(state.clone())
+            episode_rewards[episode_idx].append(rew_raw.clone())
+            episode_actions[episode_idx].append(action.clone())
+
+            curr_rollout_rew.append(rew_raw.clone())
+

-            if done:
+            if args.pass_belief_to_policy and (encoder is None):
+                episode_beliefs[episode_idx].append(belief)
+
+            if infos[0]['done_mdp'] and not done:
+                start_obs = infos[0]['start_state']
+                start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device)
                 break

         episode_returns.append(sum(curr_rollout_rew))
+        episode_lengths.append(step_idx)
+

-    # Stack episode data
-    episode_prev_obs = [torch.stack(e) for e in episode_prev_obs]
-    episode_next_obs = [torch.stack(e) for e in episode_next_obs]
-    episode_actions = [torch.stack(e) for e in episode_actions]
-    episode_rewards = [torch.stack(e) for e in episode_rewards]
+    # clean up

     if encoder is not None:
         episode_latent_means = [torch.stack(e) for e in episode_latent_means]
         episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

+    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
+    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
+    episode_actions = [torch.cat(e) for e in episode_actions]
+    episode_rewards = [torch.cat(e) for e in episode_rewards]
+
+
     # Plot price and reward trajectories
     import matplotlib.pyplot as plt
     plt.figure(figsize=(10, 3 * num_episodes))
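
Note on the new rollout loop: it no longer calls env.reset() between rollouts. Instead it assumes the vectorised PricingEnv wrapper auto-resets whenever a single MDP episode ends and reports that through the info dict (infos[0]['done_mdp'] together with the next episode's start_state), while done only flags the end of the whole meta-episode. A rough, hypothetical sketch of that contract (ToyPricingEnv and AutoResetVecWrapper are made-up names, not ddopai's actual wrapper):

import numpy as np

class ToyPricingEnv:
    """Hypothetical stand-in for a single pricing task: scalar price in, revenue out."""
    _max_episode_steps = 5

    def reset(self):
        self._t = 0
        return np.zeros(2, dtype=np.float32)            # (obs_dim,)

    def step(self, action):
        self._t += 1
        obs = np.array([self._t, float(action)], dtype=np.float32)
        reward = float(action) * 0.9                    # toy revenue
        done_mdp = self._t >= self._max_episode_steps
        return obs, reward, done_mdp, {}

class AutoResetVecWrapper:
    """Sketch of the contract the loop relies on: batched obs, auto-reset on
    episode end, and 'done_mdp' / 'start_state' reported via the info dict."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.env.reset()[None]                   # (1, obs_dim)

    def step(self, action):
        obs, rew, done_mdp, info = self.env.step(action)
        info['done_mdp'] = done_mdp
        if done_mdp:
            info['start_state'] = self.env.reset()      # first obs of the next episode
            obs = info['start_state']                   # auto-reset, like a vec env
        done_trial = False                              # the meta-episode keeps going
        return obs[None], np.array([rew]), np.array([done_trial]), [info]

venv = AutoResetVecWrapper(ToyPricingEnv())
obs = venv.reset()
for _ in range(7):                                      # steps across an episode boundary
    obs, rew, done, infos = venv.step(1.0)
    if infos[0]['done_mdp']:
        print("episode ended, next start state:", infos[0]['start_state'])

Under a contract like this, utl.env_step can keep stepping across episode boundaries and the visualisation code only has to watch done_mdp to refresh its start_obs bookkeeping.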

nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb

Lines changed: 82 additions & 48 deletions
@@ -279,21 +279,23 @@
 "    Visualise behaviour in PricingEnv: plots price (action) and revenue (reward) per timestep.\n",
 "    The environment passed to this method should be a vectorised env (DummyVecEnv or SubprocVecEnv).\n",
 "    \"\"\"\n",
-"    import matplotlib.pyplot as plt\n",
-"    import torch\n",
-"    import numpy as np\n",
-"\n",
-"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
 "\n",
 "    num_episodes = args.max_rollouts_per_task\n",
 "    unwrapped_env = env.venv.unwrapped.envs[0]\n",
 "\n",
+"    episode_all_obs = [[] for _ in range(num_episodes)]\n",
 "    episode_prev_obs = [[] for _ in range(num_episodes)]\n",
 "    episode_next_obs = [[] for _ in range(num_episodes)]\n",
 "    episode_actions = [[] for _ in range(num_episodes)]\n",
 "    episode_rewards = [[] for _ in range(num_episodes)]\n",
 "    episode_returns = []\n",
+"    episode_lengths = []\n",
 "\n",
+"    if args.pass_belief_to_policy and (encoder is None):\n",
+"        episode_beliefs = [[] for _ in range(num_episodes)]\n",
+"    else:\n",
+"        episode_beliefs = None\n",
+"    \n",
 "    if encoder is not None:\n",
 "        episode_latent_samples = [[] for _ in range(num_episodes)]\n",
 "        episode_latent_means = [[] for _ in range(num_episodes)]\n",
@@ -302,70 +304,102 @@
 "        episode_latent_samples = episode_latent_means = episode_latent_logvars = None\n",
 "\n",
 "    env.reset_task()\n",
-"    state, belief, task = utl.reset_env(env, args)\n",
-"    start_obs_raw = state.clone()\n",
-"    task = task.view(-1) if task is not None else None\n",
+"    [state, belief, task] = utl.reset_env(env, args)\n",
+"    start_obs = state.clone()\n",
 "\n",
-"    hidden_state = torch.zeros((1, args.hidden_size), device=device) if hasattr(args, 'hidden_size') else None\n",
+"    for episode_idx in range(num_episodes):\n",
 "\n",
-"    for ep_idx in range(num_episodes):\n",
-"        obs = env.reset()\n",
 "        curr_rollout_rew = []\n",
 "\n",
-"        if ep_idx == 0 and encoder is not None:\n",
-"            curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)\n",
-"            curr_latent_sample = curr_latent_sample[0].to(device)\n",
-"            curr_latent_mean = curr_latent_mean[0].to(device)\n",
-"            curr_latent_logvar = curr_latent_logvar[0].to(device)\n",
+"        \n",
 "\n",
 "        if encoder is not None:\n",
-"            episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())\n",
-"            episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())\n",
-"            episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())\n",
-"\n",
-"        for t in range(env._max_episode_steps):\n",
-"            prev_obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)\n",
-"            episode_prev_obs[ep_idx].append(prev_obs.squeeze(0).clone())\n",
-"\n",
-"            latent = utl.get_latent_for_policy(args, curr_latent_sample, curr_latent_mean, curr_latent_logvar)\n",
-"            _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)\n",
-"\n",
-"            obs, reward, done, info = env.step(action.cpu().numpy())\n",
-"            obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)\n",
-"\n",
-"            episode_next_obs[ep_idx].append(obs.squeeze(0).clone())\n",
-"            episode_actions[ep_idx].append(action.squeeze(0).clone())\n",
-"            episode_rewards[ep_idx].append(torch.tensor(reward, dtype=torch.float32, device=device).clone())\n",
-"            curr_rollout_rew.append(reward)\n",
+"            \n",
+"            if ep_idx == 0 and encoder is not None:\n",
+"                # reset to prior\n",
+"                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)\n",
+"                curr_latent_sample = curr_latent_sample[0].to(device)\n",
+"                curr_latent_mean = curr_latent_mean[0].to(device)\n",
+"                curr_latent_logvar = curr_latent_logvar[0].to(device)\n",
+"            \n",
+"            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())\n",
+"            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())\n",
+"            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())\n",
+"\n",
+"        episode_all_obs[episode_idx].append(start_obs.clone())\n",
+"        if args.pass_belief_to_policy and (encoder is None):\n",
+"            episode_beliefs[episode_idx].append(belief)\n",
+"        \n",
+"        for step_idx in range(env._max_episode_steps):\n",
+"            \n",
+"            if step_idx == 1:\n",
+"                prev_obs = start_obs.clone()\n",
+"            else:\n",
+"                prev_obs = state.clone()\n",
+"            \n",
+"            episode_prev_obs[episode_idx].append(prev_obs)\n",
+"            \n",
+"            # act\n",
+"            _, action, _ = utl.select_action(args=args,\n",
+"                                             policy=policy,\n",
+"                                             state=state.view(-1),\n",
+"                                             belief=belief,\n",
+"                                             task=task,\n",
+"                                             deterministic=True,\n",
+"                                             latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None,\n",
+"                                             latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None,\n",
+"                                             latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None,\n",
+"                                             )\n",
+"            \n",
+"            # observe reward and next obs\n",
+"            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)\n",
 "\n",
 "            if encoder is not None:\n",
+"                # update task embedding\n",
 "                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(\n",
-"                    action.reshape(1, -1).float().to(device),\n",
-"                    obs,\n",
-"                    torch.tensor([reward], dtype=torch.float32, device=device).unsqueeze(0),\n",
+"                    action.float().to(device),\n",
+"                    state,\n",
+"                    rew_raw.reshape((1, 1)).float().to(device),\n",
 "                    prev_obs,\n",
 "                    hidden_state,\n",
-"                    return_prior=False,\n",
-"                )\n",
-"                episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())\n",
-"                episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())\n",
-"                episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())\n",
+"                    return_prior=False)\n",
+"\n",
+"                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())\n",
+"                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())\n",
+"                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())\n",
+"\n",
+"            episode_all_obs[episode_idx].append(state.clone())\n",
+"            episode_next_obs[episode_idx].append(state.clone())\n",
+"            episode_rewards[episode_idx].append(rew_raw.clone())\n",
+"            episode_actions[episode_idx].append(action.clone())\n",
+"\n",
+"            curr_rollout_rew.append(rew_raw.clone())\n",
+"            \n",
 "\n",
-"            if done:\n",
+"            if args.pass_belief_to_policy and (encoder is None):\n",
+"                episode_beliefs[episode_idx].append(belief)\n",
+"\n",
+"            if infos[0]['done_mdp'] and not done:\n",
+"                start_obs = infos[0]['start_state']\n",
+"                start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device)\n",
 "                break\n",
 "\n",
 "        episode_returns.append(sum(curr_rollout_rew))\n",
+"        episode_lengths.append(step_idx)\n",
+"\n",
 "\n",
-"    # Stack episode data\n",
-"    episode_prev_obs = [torch.stack(e) for e in episode_prev_obs]\n",
-"    episode_next_obs = [torch.stack(e) for e in episode_next_obs]\n",
-"    episode_actions = [torch.stack(e) for e in episode_actions]\n",
-"    episode_rewards = [torch.stack(e) for e in episode_rewards]\n",
+"    # clean up\n",
 "\n",
 "    if encoder is not None:\n",
 "        episode_latent_means = [torch.stack(e) for e in episode_latent_means]\n",
 "        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]\n",
 "\n",
+"    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]\n",
+"    episode_next_obs = [torch.cat(e) for e in episode_next_obs]\n",
+"    episode_actions = [torch.cat(e) for e in episode_actions]\n",
+"    episode_rewards = [torch.cat(e) for e in episode_rewards]\n",
+"\n",
+"\n",
 "    # Plot price and reward trajectories\n",
 "    import matplotlib.pyplot as plt\n",
 "    plt.figure(figsize=(10, 3 * num_episodes))\n",
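
A side note on the clean-up step in both files: utl.env_step returns batched tensors of shape (1, obs_dim) (one row per vectorised env), so each per-step entry already carries a leading batch dimension, which is presumably why the final stacking switches from torch.stack to torch.cat. A minimal standalone illustration (plain PyTorch, not ddopai code):

import torch

# Four per-step observations as a vectorised env would return them: shape (1, 3) each.
steps = [torch.randn(1, 3) for _ in range(4)]

stacked = torch.stack(steps)  # shape (4, 1, 3): keeps the singleton batch dimension
catted = torch.cat(steps)     # shape (4, 3): one row per timestep, as the plots expect

print(stacked.shape, catted.shape)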
