
Commit 23cb60b

Merge pull request #3 from MLDS-Laboratory/aarunsrinivas5/issue-1
Switch TD(lambda) to GAE
2 parents 9ac3954 + c5bc5ce
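
This merge switches the return and advantage computation in compute_returns_and_advantage from a forward-view TD(lambda) recursion to the standard backward Generalized Advantage Estimation (GAE) pass, in both the stock RolloutBuffer and the custom beta-weighted buffer below. Writing d_{t+1} for the terminal flag, the restored loop computes

\delta_t = r_t + \gamma \, (1 - d_{t+1}) \, V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma \lambda \, (1 - d_{t+1}) \, \hat{A}_{t+1},

and recovers the TD(lambda) return as R_t = \hat{A}_t + V(s_t).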

2 files changed: 58 additions & 53 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -46,8 +46,10 @@ src
 .cache
 *.lprof
 *.prof
+*.zip

 MUJOCO_LOG.TXT

+dummy.py
 rsa2c/
 exptd3/

stable_baselines3/common/buffers.py

Lines changed: 56 additions & 53 deletions
@@ -420,34 +420,37 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray
         :param dones: if the last step was a terminal step (one bool for each env).
         """
         # # Convert to numpy
-        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
-
-        # last_gae_lam = 0
-        # for step in reversed(range(self.buffer_size)):
-        #     if step == self.buffer_size - 1:
-        #         next_non_terminal = 1.0 - dones.astype(np.float32)
-        #         next_values = last_values
-        #     else:
-        #         next_non_terminal = 1.0 - self.episode_starts[step + 1]
-        #         next_values = self.values[step + 1]
-        #     delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
-        #     last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
-        #     self.advantages[step] = last_gae_lam
-        # # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
-        # # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
-        # self.returns = self.advantages + self.values
-
         last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
-        values = np.concatenate((self.values, last_values.reshape(1, -1)))
-        dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
-        next_non_terminal = (1.0 - dones.astype(np.float32))[1:]

-        returns = [self.values[-1]]
-        interm = self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * values[1:]
+        last_gae_lam = 0
         for step in reversed(range(self.buffer_size)):
-            returns.append(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * returns[-1])
-        self.returns = np.stack(list(reversed(returns))[:-1], 0)
-        self.advantages = self.returns - self.values
+            if step == self.buffer_size - 1:
+                next_non_terminal = 1.0 - dones.astype(np.float32)
+                next_values = last_values
+            else:
+                next_non_terminal = 1.0 - self.episode_starts[step + 1]
+                next_values = self.values[step + 1]
+            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+            self.advantages[step] = last_gae_lam
+        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+        self.returns = self.advantages + self.values
+
+        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+        # values = np.concatenate((self.values, last_values.reshape(1, -1)))
+        # dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
+        # next_non_terminal = (1.0 - dones.astype(np.float32))[1:]
+
+        # # self.returns = self.rewards + self.gamma * next_non_terminal * values[1:]
+        # # self.advantages = self.returns - self.values
+
+        # returns = [self.values[-1]]
+        # interm = self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * values[1:]
+        # for step in reversed(range(self.buffer_size)):
+        #     returns.append(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * returns[-1])
+        # self.returns = np.stack(list(reversed(returns))[:-1], 0)
+        # self.advantages = self.returns - self.values

     def add(
         self,
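
This hunk puts the stock RolloutBuffer back on the textbook backward GAE recursion and keeps the forward-view return computation only as a comment. For reference, a minimal self-contained NumPy sketch of what the uncommented loop does for a single environment (the function name and flat array shapes are illustrative, not the repo's API):

import numpy as np

def gae_advantages(rewards, values, episode_starts, last_value, last_done,
                   gamma=0.99, gae_lambda=0.95):
    # Backward GAE pass over a length-T rollout for one env.
    # delta is the one-step TD error; advantages accumulate with decay
    # gamma * gae_lambda, and the non-terminal mask cuts the recursion
    # at episode boundaries.
    T = len(rewards)
    advantages = np.zeros(T)
    last_gae_lam = 0.0
    for step in reversed(range(T)):
        if step == T - 1:
            next_non_terminal = 1.0 - float(last_done)
            next_value = last_value
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_value = values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    returns = advantages + values  # TD(lambda) return estimate
    return advantages, returns

With gae_lambda=1 and no episode boundaries this reduces to discounted Monte Carlo returns minus the value baseline; with gae_lambda=0 the advantage collapses to the one-step TD error.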
@@ -541,37 +544,37 @@ def __init__(self, buffer_size, observation_space, action_space, device = "auto"

     def compute_returns_and_advantage(self, last_values, dones):

-        # # Convert to numpy
-        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+        # Convert to numpy
+        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]

-        # last_gae_lam = 0
-        # for step in reversed(range(self.buffer_size)):
-        #     if step == self.buffer_size - 1:
-        #         next_non_terminal = 1.0 - dones.astype(np.float32)
-        #         next_values = last_values
-        #     else:
-        #         next_non_terminal = 1.0 - self.episode_starts[step + 1]
-        #         next_values = self.values[step + 1]
-        #     delta = np.exp(self.beta * self.rewards[step] + self.gamma * np.log(1e-15 + np.maximum(next_values, 0)) * next_non_terminal) - self.values[step]
-        #     # delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
-        #     last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
-        #     self.advantages[step] = last_gae_lam
-        # # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
-        # # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
-        # self.returns = self.advantages + self.values
+        last_gae_lam = 0
+        for step in reversed(range(self.buffer_size)):
+            if step == self.buffer_size - 1:
+                next_non_terminal = 1.0 - dones.astype(np.float32)
+                next_values = last_values
+            else:
+                next_non_terminal = 1.0 - self.episode_starts[step + 1]
+                next_values = self.values[step + 1]
+            delta = np.exp(self.beta * self.rewards[step] + self.gamma * np.log(1e-15 + np.maximum(next_values, 0)) * next_non_terminal) - self.values[step]
+            # delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+            self.advantages[step] = last_gae_lam
+        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+        self.returns = self.advantages + self.values


-        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
-        values = np.concatenate((self.values, last_values.reshape(1, -1)))
-        dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
-        next_non_terminal = (1.0 - dones.astype(np.float32))[1:]
-
-        returns = [self.values[-1]]
-        interm = self.beta * self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * np.log(1e-15 + np.maximum(0, values[1:]))
-        for step in reversed(range(self.buffer_size)):
-            returns.append(np.exp(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * np.log(1e-15 + np.maximum(0, returns[-1]))))
-        self.returns = np.stack(list(reversed(returns))[:-1], 0)
-        self.advantages = self.returns - self.values
+        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+        # values = np.concatenate((self.values, last_values.reshape(1, -1)))
+        # dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
+        # next_non_terminal = (1.0 - dones.astype(np.float32))[1:]
+
+        # returns = [self.values[-1]]
+        # interm = self.beta * self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * np.log(1e-15 + np.maximum(0, values[1:]))
+        # for step in reversed(range(self.buffer_size)):
+        #     returns.append(np.exp(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * np.log(1e-15 + np.maximum(0, returns[-1]))))
+        # self.returns = np.stack(list(reversed(returns))[:-1], 0)
+        # self.advantages = (self.returns - self.values)


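The second hunk applies the same switch to the custom buffer, whose TD error exponentiates the backup: the target is exp(beta * r + gamma * log V(s')) instead of r + gamma * V(s'), which reads as an exponential-utility (risk-sensitive) variant of the Bellman target, with the 1e-15 floor keeping the log finite near zero. A minimal sketch of that delta in isolation (names are illustrative; beta is the buffer's own parameter, so the default shown is an assumption, and the np.maximum clamp suggests value estimates are meant to stay non-negative):

import numpy as np

def risk_sensitive_delta(reward, value, next_value, next_non_terminal,
                         gamma=0.99, beta=1.0):
    # Multiplicative backup: target = exp(beta * r + gamma * log V(s'))
    # on non-terminal steps and exp(beta * r) at terminals; the 1e-15
    # floor guards the log when the next value estimate is near zero.
    log_next = np.log(1e-15 + np.maximum(next_value, 0.0))
    target = np.exp(beta * reward + gamma * log_next * next_non_terminal)
    return target - value

Everything else in the loop (the last_gae_lam accumulator, the non-terminal masks, returns = advantages + values) is identical to the standard GAE pass above.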