|
279 | 279 | " Visualise behaviour in PricingEnv: plots price (action) and revenue (reward) per timestep.\n", |
280 | 280 | " The environment passed to this method should be a vectorised env (DummyVecEnv or SubprocVecEnv).\n", |
281 | 281 | " \"\"\"\n", |
282 | | - " import matplotlib.pyplot as plt\n", |
283 | | - " import torch\n", |
284 | | - " import numpy as np\n", |
285 | | - "\n", |
286 | | - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |
287 | 282 | "\n", |
288 | 283 | " num_episodes = args.max_rollouts_per_task\n", |
289 | 284 | " unwrapped_env = env.venv.unwrapped.envs[0]\n", |
290 | 285 | "\n", |
| 286 | + " episode_all_obs = [[] for _ in range(num_episodes)]\n", |
291 | 287 | " episode_prev_obs = [[] for _ in range(num_episodes)]\n", |
292 | 288 | " episode_next_obs = [[] for _ in range(num_episodes)]\n", |
293 | 289 | " episode_actions = [[] for _ in range(num_episodes)]\n", |
294 | 290 | " episode_rewards = [[] for _ in range(num_episodes)]\n", |
295 | 291 | " episode_returns = []\n", |
| 292 | + " episode_lengths = []\n", |
296 | 293 | "\n", |
| 294 | + " if args.pass_belief_to_policy and (encoder is None):\n", |
| 295 | + " episode_beliefs = [[] for _ in range(num_episodes)]\n", |
| 296 | + " else:\n", |
| 297 | + " episode_beliefs = None\n", |
| 298 | + " \n", |
297 | 299 | " if encoder is not None:\n", |
298 | 300 | " episode_latent_samples = [[] for _ in range(num_episodes)]\n", |
299 | 301 | " episode_latent_means = [[] for _ in range(num_episodes)]\n", |
300 | 302 | "        episode_latent_logvars = [[] for _ in range(num_episodes)]\n", |
301 | 303 | "    else:\n", |
302 | 304 | "        episode_latent_samples = episode_latent_means = episode_latent_logvars = None\n", |
| | + "        # without an encoder, the latent inputs to select_action must default to None\n", |
| | + "        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None\n", |
303 | 305 | "\n", |
304 | 306 | " env.reset_task()\n", |
305 | | - " state, belief, task = utl.reset_env(env, args)\n", |
306 | | - " start_obs_raw = state.clone()\n", |
307 | | - " task = task.view(-1) if task is not None else None\n", |
| 307 | + " [state, belief, task] = utl.reset_env(env, args)\n", |
| 308 | + " start_obs = state.clone()\n", |
308 | 309 | "\n", |
309 | | - " hidden_state = torch.zeros((1, args.hidden_size), device=device) if hasattr(args, 'hidden_size') else None\n", |
| 310 | + " for episode_idx in range(num_episodes):\n", |
310 | 311 | "\n", |
311 | | - " for ep_idx in range(num_episodes):\n", |
312 | | - " obs = env.reset()\n", |
313 | 312 | " curr_rollout_rew = []\n", |
314 | 313 | "\n", |
315 | | - " if ep_idx == 0 and encoder is not None:\n", |
316 | | - " curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)\n", |
317 | | - " curr_latent_sample = curr_latent_sample[0].to(device)\n", |
318 | | - " curr_latent_mean = curr_latent_mean[0].to(device)\n", |
319 | | - " curr_latent_logvar = curr_latent_logvar[0].to(device)\n", |
| 314 | + " \n", |
320 | 315 | "\n", |
321 | 316 | " if encoder is not None:\n", |
322 | | - " episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())\n", |
323 | | - " episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())\n", |
324 | | - " episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())\n", |
325 | | - "\n", |
326 | | - " for t in range(env._max_episode_steps):\n", |
327 | | - " prev_obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)\n", |
328 | | - " episode_prev_obs[ep_idx].append(prev_obs.squeeze(0).clone())\n", |
329 | | - "\n", |
330 | | - " latent = utl.get_latent_for_policy(args, curr_latent_sample, curr_latent_mean, curr_latent_logvar)\n", |
331 | | - " _, action, _ = policy.act(prev_obs, latent, belief=None, task=task, deterministic=True)\n", |
332 | | - "\n", |
333 | | - " obs, reward, done, info = env.step(action.cpu().numpy())\n", |
334 | | - " obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)\n", |
335 | | - "\n", |
336 | | - " episode_next_obs[ep_idx].append(obs.squeeze(0).clone())\n", |
337 | | - " episode_actions[ep_idx].append(action.squeeze(0).clone())\n", |
338 | | - " episode_rewards[ep_idx].append(torch.tensor(reward, dtype=torch.float32, device=device).clone())\n", |
339 | | - " curr_rollout_rew.append(reward)\n", |
| 317 | + " \n", |
| 318 | + "            if episode_idx == 0:\n", |
| 319 | + " # reset to prior\n", |
| 320 | + " curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)\n", |
| 321 | + " curr_latent_sample = curr_latent_sample[0].to(device)\n", |
| 322 | + " curr_latent_mean = curr_latent_mean[0].to(device)\n", |
| 323 | + " curr_latent_logvar = curr_latent_logvar[0].to(device)\n", |
| 324 | + " \n", |
| 325 | + " episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())\n", |
| 326 | + " episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())\n", |
| 327 | + " episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())\n", |
| 328 | + "\n", |
| 329 | + " episode_all_obs[episode_idx].append(start_obs.clone())\n", |
| 330 | + " if args.pass_belief_to_policy and (encoder is None):\n", |
| 331 | + " episode_beliefs[episode_idx].append(belief)\n", |
| 332 | + " \n", |
| 333 | + " for step_idx in range(env._max_episode_steps):\n", |
| 334 | + " \n", |
| 335 | + "            if step_idx == 0:\n", |
| 336 | + " prev_obs = start_obs.clone()\n", |
| 337 | + " else:\n", |
| 338 | + " prev_obs = state.clone()\n", |
| 339 | + " \n", |
| 340 | + " episode_prev_obs[episode_idx].append(prev_obs)\n", |
| 341 | + " \n", |
| 342 | + " # act\n", |
| 343 | + " _, action, _ = utl.select_action(args=args,\n", |
| 344 | + " policy=policy,\n", |
| 345 | + " state=state.view(-1),\n", |
| 346 | + " belief=belief,\n", |
| 347 | + " task=task,\n", |
| 348 | + " deterministic=True,\n", |
| 349 | + " latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None,\n", |
| 350 | + " latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None,\n", |
| 351 | + " latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None,\n", |
| 352 | + " )\n", |
| 353 | + " \n", |
| 354 | + " # observe reward and next obs\n", |
| 355 | + " [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)\n", |
340 | 356 | "\n", |
341 | 357 | " if encoder is not None:\n", |
| 358 | + " # update task embedding\n", |
342 | 359 | " curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(\n", |
343 | | - " action.reshape(1, -1).float().to(device),\n", |
344 | | - " obs,\n", |
345 | | - " torch.tensor([reward], dtype=torch.float32, device=device).unsqueeze(0),\n", |
| 360 | + " action.float().to(device),\n", |
| 361 | + " state,\n", |
| 362 | + " rew_raw.reshape((1, 1)).float().to(device),\n", |
346 | 363 | " prev_obs,\n", |
347 | 364 | " hidden_state,\n", |
348 | | - " return_prior=False,\n", |
349 | | - " )\n", |
350 | | - " episode_latent_samples[ep_idx].append(curr_latent_sample[0].clone())\n", |
351 | | - " episode_latent_means[ep_idx].append(curr_latent_mean[0].clone())\n", |
352 | | - " episode_latent_logvars[ep_idx].append(curr_latent_logvar[0].clone())\n", |
| 365 | + " return_prior=False)\n", |
| 366 | + "\n", |
| 367 | + " episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())\n", |
| 368 | + " episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())\n", |
| 369 | + " episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())\n", |
| 370 | + "\n", |
| 371 | + " episode_all_obs[episode_idx].append(state.clone())\n", |
| 372 | + " episode_next_obs[episode_idx].append(state.clone())\n", |
| 373 | + " episode_rewards[episode_idx].append(rew_raw.clone())\n", |
| 374 | + " episode_actions[episode_idx].append(action.clone())\n", |
| 375 | + "\n", |
| 376 | + " curr_rollout_rew.append(rew_raw.clone())\n", |
| 377 | + " \n", |
353 | 378 | "\n", |
354 | | - " if done:\n", |
| 379 | + " if args.pass_belief_to_policy and (encoder is None):\n", |
| 380 | + " episode_beliefs[episode_idx].append(belief)\n", |
| 381 | + "\n", |
| 382 | + " if infos[0]['done_mdp'] and not done:\n", |
| 383 | + " start_obs = infos[0]['start_state']\n", |
| 384 | + " start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device)\n", |
355 | 385 | " break\n", |
356 | 386 | "\n", |
357 | 387 | " episode_returns.append(sum(curr_rollout_rew))\n", |
| 388 | + "        episode_lengths.append(step_idx + 1)\n", |
| 389 | + "\n", |
358 | 390 | "\n", |
359 | | - " # Stack episode data\n", |
360 | | - " episode_prev_obs = [torch.stack(e) for e in episode_prev_obs]\n", |
361 | | - " episode_next_obs = [torch.stack(e) for e in episode_next_obs]\n", |
362 | | - " episode_actions = [torch.stack(e) for e in episode_actions]\n", |
363 | | - " episode_rewards = [torch.stack(e) for e in episode_rewards]\n", |
| 391 | + " # clean up\n", |
364 | 392 | "\n", |
365 | 393 | " if encoder is not None:\n", |
366 | 394 | " episode_latent_means = [torch.stack(e) for e in episode_latent_means]\n", |
367 | 395 | " episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]\n", |
368 | 396 | "\n", |
| 397 | + " episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]\n", |
| 398 | + " episode_next_obs = [torch.cat(e) for e in episode_next_obs]\n", |
| 399 | + " episode_actions = [torch.cat(e) for e in episode_actions]\n", |
| 400 | + " episode_rewards = [torch.cat(e) for e in episode_rewards]\n", |
| 401 | + "\n", |
| 402 | + "\n", |
369 | 403 | " # Plot price and reward trajectories\n", |
370 | 404 | " import matplotlib.pyplot as plt\n", |
371 | 405 | " plt.figure(figsize=(10, 3 * num_episodes))\n", |
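
For reference, a minimal sketch of how this method might be driven. Every name below is an assumption rather than something confirmed by the diff: `PricingEnv` is the notebook's own environment class, `visualise_behaviour` stands in for the (unshown) method name, `args`/`policy`/`encoder` would come from earlier cells or a trained variBAD-style model, and the wrapper imports assume stable-baselines3 (the notebook may use equivalent bundled vec-env classes instead):

# Hypothetical usage sketch -- all names and the call signature are assumptions.
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# A single-env DummyVecEnv wrapped in VecNormalize yields exactly the structure
# the method body unwraps: env.venv (the DummyVecEnv) .unwrapped.envs[0]
# (the raw PricingEnv).
env = VecNormalize(DummyVecEnv([lambda: PricingEnv()]), norm_reward=True)

# policy and encoder come from a trained variBAD-style model; encoder may be
# None, in which case the latent-tracking branches above are skipped.
visualise_behaviour(env, args, policy, encoder)

The `VecNormalize(DummyVecEnv(...))` nesting is what makes the `env.venv.unwrapped.envs[0]` access in the method resolve to the underlying `PricingEnv`, which is the docstring's stated requirement for the environment argument.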
|