Skip to content

Commit bae46c2

Browse files
author
Ervin T
committed
[bug-fix] Use correct memories for LSTM SAC (#5228)
* Use correct memories for LSTM SAC
* Add some comments

(cherry picked from commit 7077302)
1 parent efa8f34 commit bae46c2

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -497,30 +497,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
497497
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
498498
)
499499
]
500-
offset = 1 if self.policy.sequence_length > 1 else 0
501-
next_value_memories_list = [
502-
ModelUtils.list_to_tensor(
503-
batch[BufferKey.CRITIC_MEMORY][i]
504-
) # only pass value part of memory to target network
505-
for i in range(
506-
offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
507-
)
508-
]
509500

510501
if len(memories_list) > 0:
511502
memories = torch.stack(memories_list).unsqueeze(0)
512503
value_memories = torch.stack(value_memories_list).unsqueeze(0)
513-
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0)
514504
else:
515505
memories = None
516506
value_memories = None
517-
next_value_memories = None
518507

519508
# Q and V network memories are 0'ed out, since we don't have them during inference.
520509
q_memories = (
521-
torch.zeros_like(next_value_memories)
522-
if next_value_memories is not None
523-
else None
510+
torch.zeros_like(value_memories) if value_memories is not None else None
524511
)
525512

526513
# Copy normalizers from policy
@@ -568,6 +555,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
568555
q1_stream, q2_stream = q1_out, q2_out
569556

570557
with torch.no_grad():
558+
# Since we didn't record the next value memories, evaluate one step in the critic to
559+
# get them.
560+
if value_memories is not None:
561+
# Get the first observation in each sequence
562+
just_first_obs = [
563+
_obs[:: self.policy.sequence_length] for _obs in current_obs
564+
]
565+
_, next_value_memories = self._critic.critic_pass(
566+
just_first_obs, value_memories, sequence_length=1
567+
)
568+
else:
569+
next_value_memories = None
571570
target_values, _ = self.target_network(
572571
next_obs,
573572
memories=next_value_memories,

0 commit comments

Comments (0)