Skip to content

Commit 5a0119c

Browse files
epoint95 (Edward Point)
and authored
fix entropy bug from scalar to tensor input to loss function (#1524)
Co-authored-by: Edward Point <edwardpoint@MacBook-Pro.local>
1 parent c70cebc commit 5a0119c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

docs/tutorials/training_agents/vector_a2c.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ def update_parameters(
442442
ep_value_preds = torch.zeros(n_steps_per_update, n_envs, device=device)
443443
ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
444444
ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device=device)
445+
ep_entropies = torch.zeros(n_steps_per_update, n_envs, device=device)
445446
masks = torch.zeros(n_steps_per_update, n_envs, device=device)
446447

447448
# at the start of training reset all envs to get an initial state
@@ -463,6 +464,7 @@ def update_parameters(
463464
ep_value_preds[step] = torch.squeeze(state_value_preds)
464465
ep_rewards[step] = torch.tensor(rewards, device=device)
465466
ep_action_log_probs[step] = action_log_probs
467+
ep_entropies[step] = entropy
466468

467469
# add a mask (for the return calculation later);
468470
# for each env the mask is 1 if the episode is ongoing and 0 if it is terminated (not by truncation!)
@@ -473,7 +475,7 @@ def update_parameters(
473475
ep_rewards,
474476
ep_action_log_probs,
475477
ep_value_preds,
476-
entropy,
478+
ep_entropies,
477479
masks,
478480
gamma,
479481
lam,
@@ -487,7 +489,7 @@ def update_parameters(
487489
# log the losses and entropy
488490
critic_losses.append(critic_loss.detach().cpu().numpy())
489491
actor_losses.append(actor_loss.detach().cpu().numpy())
490-
entropies.append(entropy.detach().mean().cpu().numpy())
492+
entropies.append(ep_entropies.detach().mean().cpu().numpy())
491493

492494

493495
# %%

0 commit comments

Comments (0)