|
1 | 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb. |
2 | 2 |
|
3 | 3 | # %% auto 0 |
4 | | -__all__ = ['device', 'PricingEnv'] |
| 4 | +__all__ = ['device', 'PricingEnv', 'plot_actions_reward'] |
5 | 5 |
|
6 | 6 | # %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 1 |
7 | 7 | import gym |
@@ -261,153 +261,141 @@ def visualise_behaviour(env, |
261 | 261 | """ |
262 | 262 |
|
263 | 263 | num_episodes = args.max_rollouts_per_task |
264 | | - unwrapped_env = env.venv.unwrapped.envs[0] |
265 | 264 |
|
266 | | - episode_all_obs = [[] for _ in range(num_episodes)] |
| 265 | + # --- initialise things we want to keep track of --- |
| 266 | + |
267 | 267 | episode_prev_obs = [[] for _ in range(num_episodes)] |
268 | 268 | episode_next_obs = [[] for _ in range(num_episodes)] |
269 | 269 | episode_actions = [[] for _ in range(num_episodes)] |
270 | 270 | episode_rewards = [[] for _ in range(num_episodes)] |
| 271 | + |
271 | 272 | episode_returns = [] |
272 | 273 | episode_lengths = [] |
273 | 274 |
|
274 | | - if args.pass_belief_to_policy and (encoder is None): |
275 | | - episode_beliefs = [[] for _ in range(num_episodes)] |
276 | | - else: |
277 | | - episode_beliefs = None |
278 | | - |
279 | 275 | if encoder is not None: |
280 | 276 | episode_latent_samples = [[] for _ in range(num_episodes)] |
281 | 277 | episode_latent_means = [[] for _ in range(num_episodes)] |
282 | 278 | episode_latent_logvars = [[] for _ in range(num_episodes)] |
283 | 279 | else: |
284 | | - episode_latent_samples = episode_latent_means = episode_latent_logvars = None |
| 280 | + curr_latent_sample = curr_latent_mean = curr_latent_logvar = None |
| 281 | + episode_latent_means = episode_latent_logvars = None |
| 282 | + |
| 283 | + # --- roll out policy --- |
285 | 284 |
|
| 285 | + # (re)set environment |
286 | 286 | env.reset_task() |
287 | | - [state, belief, task] = utl.reset_env(env, args) |
288 | | - start_obs = state.clone() |
| 287 | + state, belief, task = utl.reset_env(env, args) |
| 288 | + state = state.reshape((1, -1)).to(device) |
| 289 | + task = task.view(-1) if task is not None else None |
289 | 290 |
|
290 | 291 | for episode_idx in range(num_episodes): |
291 | 292 |
|
292 | 293 | curr_rollout_rew = [] |
293 | 294 |
|
294 | | - |
295 | | - |
296 | 295 | if encoder is not None: |
297 | | - |
298 | | - if episode_idx == 0 and encoder is not None: |
| 296 | + if episode_idx == 0: |
299 | 297 | # reset to prior |
300 | 298 | curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1) |
301 | 299 | curr_latent_sample = curr_latent_sample[0].to(device) |
302 | 300 | curr_latent_mean = curr_latent_mean[0].to(device) |
303 | 301 | curr_latent_logvar = curr_latent_logvar[0].to(device) |
304 | | - |
305 | 302 | episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone()) |
306 | 303 | episode_latent_means[episode_idx].append(curr_latent_mean[0].clone()) |
307 | 304 | episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone()) |
308 | 305 |
|
309 | | - episode_all_obs[episode_idx].append(start_obs.clone()) |
310 | | - if args.pass_belief_to_policy and (encoder is None): |
311 | | - episode_beliefs[episode_idx].append(belief) |
312 | | - |
313 | | - for step_idx in range(env._max_episode_steps): |
314 | | - |
315 | | - if step_idx == 1: |
316 | | - prev_obs = start_obs.clone() |
317 | | - else: |
318 | | - prev_obs = state.clone() |
319 | | - |
320 | | - episode_prev_obs[episode_idx].append(prev_obs) |
321 | | - |
322 | | - # act |
323 | | - _, action, _ = utl.select_action(args=args, |
324 | | - policy=policy, |
325 | | - state=state.view(-1), |
326 | | - belief=belief, |
327 | | - task=task, |
328 | | - deterministic=True, |
329 | | - latent_sample=curr_latent_sample.view(-1) if (curr_latent_sample is not None) else None, |
330 | | - latent_mean=curr_latent_mean.view(-1) if (curr_latent_mean is not None) else None, |
331 | | - latent_logvar=curr_latent_logvar.view(-1) if (curr_latent_logvar is not None) else None, |
332 | | - ) |
333 | | - |
| 306 | + for step_idx in range(1, env._max_episode_steps + 1): |
| 307 | + |
| 308 | + episode_prev_obs[episode_idx].append(state.clone()) |
| 309 | + prev_state = state.clone() |
| 310 | + |
| 311 | + latent = utl.get_latent_for_policy(args, |
| 312 | + latent_sample=curr_latent_sample, |
| 313 | + latent_mean=curr_latent_mean, |
| 314 | + latent_logvar=curr_latent_logvar) |
| 315 | + _, action, _ = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, deterministic=True) |
| 316 | + action = action.reshape((1, *action.shape)) |
| 317 | + |
334 | 318 | # observe reward and next obs |
335 | | - [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args) |
| 319 | + (state, belief, task), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args) |
| 320 | + state = state.reshape((1, -1)).to(device) |
| 321 | + task = task.view(-1) if task is not None else None |
336 | 322 |
|
337 | 323 | if encoder is not None: |
338 | 324 | # update task embedding |
339 | 325 | curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder( |
340 | 326 | action.float().to(device), |
341 | 327 | state, |
342 | 328 | rew_raw.reshape((1, 1)).float().to(device), |
343 | | - prev_obs, |
| 329 | + prev_state, |
344 | 330 | hidden_state, |
345 | 331 | return_prior=False) |
346 | 332 |
|
347 | 333 | episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone()) |
348 | 334 | episode_latent_means[episode_idx].append(curr_latent_mean[0].clone()) |
349 | 335 | episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone()) |
350 | 336 |
|
351 | | - episode_all_obs[episode_idx].append(state.clone()) |
352 | 337 | episode_next_obs[episode_idx].append(state.clone()) |
353 | 338 | episode_rewards[episode_idx].append(rew_raw.clone()) |
354 | 339 | episode_actions[episode_idx].append(action.clone()) |
355 | 340 |
|
356 | | - curr_rollout_rew.append(rew_raw.clone()) |
357 | | - |
358 | | - |
359 | | - if args.pass_belief_to_policy and (encoder is None): |
360 | | - episode_beliefs[episode_idx].append(belief) |
361 | | - |
362 | | - if infos[0]['done_mdp'] and not done: |
363 | | - start_obs = infos[0]['start_state'] |
364 | | - start_obs = torch.from_numpy(start_obs).float().reshape((1, -1)).to(device) |
| 341 | + if infos[0]['done_mdp']: |
365 | 342 | break |
366 | 343 |
|
367 | | -        episode_returns.append(sum(curr_rollout_rew))
| 344 | +        episode_returns.append(sum(episode_rewards[episode_idx]))  # sum the per-step rewards collected above
368 | 345 | episode_lengths.append(step_idx) |
369 | 346 |
|
370 | | - |
371 | 347 | # clean up |
372 | | - |
373 | 348 | if encoder is not None: |
374 | 349 | episode_latent_means = [torch.stack(e) for e in episode_latent_means] |
375 | 350 | episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars] |
376 | 351 |
|
377 | 352 | episode_prev_obs = [torch.cat(e) for e in episode_prev_obs] |
378 | 353 | episode_next_obs = [torch.cat(e) for e in episode_next_obs] |
379 | 354 | episode_actions = [torch.cat(e) for e in episode_actions] |
380 | | - episode_rewards = [torch.cat(e) for e in episode_rewards] |
381 | | - |
382 | | - |
383 | | - # Plot price and reward trajectories |
384 | | - import matplotlib.pyplot as plt |
385 | | - plt.figure(figsize=(10, 3 * num_episodes)) |
386 | | - for i in range(num_episodes): |
387 | | - plt.subplot(num_episodes, 2, 2 * i + 1) |
388 | | - plt.plot(episode_actions[i].cpu().numpy(), label="Price") |
389 | | - plt.ylabel("Price") |
390 | | - plt.xlabel("Timestep") |
391 | | - plt.title(f"Episode {i}: Price") |
392 | | - |
393 | | - plt.subplot(num_episodes, 2, 2 * i + 2) |
394 | | - plt.plot(episode_rewards[i].cpu().numpy(), label="Revenue", color='green') |
395 | | - plt.ylabel("Revenue") |
396 | | - plt.xlabel("Timestep") |
397 | | - plt.title(f"Episode {i}: Revenue") |
398 | | - |
399 | | - plt.tight_layout() |
400 | | - if image_folder is not None: |
401 | | - plt.savefig(f"{image_folder}/{iter_idx}_pricing_behaviour.png") |
402 | | - plt.close() |
403 | | - else: |
404 | | - plt.show() |
405 | | - |
| 355 | + episode_rewards = [torch.cat(r) for r in episode_rewards] |
| 356 | + |
| 357 | + plot_actions_reward( |
| 358 | + episode_actions=episode_actions, |
| 359 | + episode_rewards=episode_rewards, |
| 360 | + episode_lengths=episode_lengths, |
| 361 | + image_folder=image_folder, |
| 362 | + iter_idx=iter_idx |
| 363 | + ) |
406 | 364 | return episode_latent_means, episode_latent_logvars, \ |
407 | | - episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \ |
408 | | - episode_returns |
| 365 | + episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \ |
| 366 | + episode_returns |
| 367 | + |
| 368 | + |
| 369 | + |
409 | 370 |
|
410 | 371 |
|
411 | 372 |
|
412 | 373 |
|
413 | 374 |
|
| 375 | + |
| 376 | +# %% ../../../../nbs/50_meta_learning/53_environments/01_pricing_env/10_pricing_env.ipynb 3 |
| 377 | +def plot_actions_reward( |
| 378 | + episode_actions: List[torch.Tensor], |
| 379 | + episode_rewards: List[torch.Tensor], |
| 380 | + episode_lengths: List[int], |
| 381 | + image_folder: Optional[str] = None, |
| 382 | + iter_idx: int = 0 |
| 383 | +): |
| 384 | + """ |
| 385 | + Plot actions and rewards for each episode. |
| 386 | + """ |
| 387 | +    import matplotlib.pyplot as plt  # imported locally, mirroring the inline import used by the previous plotting code
| 388 | +
| 389 | +    plt.figure(figsize=(12, 6))
| 390 | +    for i, (actions, rewards, length) in enumerate(zip(episode_actions, episode_rewards, episode_lengths)):
| 391 | +        plt.plot(range(length), actions.cpu().numpy(), label=f'Episode {i+1} Actions')
| 392 | +        plt.plot(range(length), rewards.cpu().numpy(), label=f'Episode {i+1} Rewards', linestyle='--')
| 393 | +
| 394 | +    plt.xlabel('Time Step')
| 395 | +    plt.ylabel('Value')
| 396 | +    plt.title('Actions and Rewards per Episode')
| 397 | +    plt.legend()
| 398 | +
| 399 | +    if image_folder:
| 400 | +        plt.savefig(f"{image_folder}/actions_rewards_iter_{iter_idx}.png")
| 401 | +        plt.close()  # close the figure so repeated evaluation calls do not leak open figures
| 402 | +    else:
| 403 | +        plt.show()
| 404 | +
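
The exported plot_actions_reward helper can also be exercised on its own. Below is a quick sketch with synthetic data; the per-episode tensor shapes are an assumption meant to mirror what visualise_behaviour collects (a single price and a single revenue value per step):

    import torch

    # two fake episodes of different lengths, one action (price) and one reward (revenue) per step
    fake_actions = [torch.rand(15, 1), torch.rand(12, 1)]
    fake_rewards = [torch.rand(15, 1), torch.rand(12, 1)]
    fake_lengths = [15, 12]

    # with image_folder=None the figure is shown interactively instead of being written to disk
    plot_actions_reward(fake_actions, fake_rewards, fake_lengths, image_folder=None, iter_idx=0)
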