6 changes: 6 additions & 0 deletions .gitignore
@@ -46,5 +46,11 @@ src
.cache
*.lprof
*.prof
*.zip

MUJOCO_LOG.TXT

dummy.py
rsa2c/
exptd3/
train.py
11 changes: 11 additions & 0 deletions example.py
@@ -0,0 +1,11 @@
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO, RSPPO
from stable_baselines3.common.utils import set_random_seed

set_random_seed(42)

env = gym.make('CartPole-v1')
model = RSPPO('MlpPolicy', env, verbose=1)

model.learn(total_timesteps=1_000_000)
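
Below the training call, a minimal evaluation/save sketch (not part of the diff), assuming RSPPO inherits the standard stable-baselines3 model API (evaluate_policy, save, load); the file name rsppo_cartpole is only an illustrative placeholder:

from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the trained policy over a few episodes (standard SB3 helper)
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
print(f"mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")

# Persist and reload the model (standard BaseAlgorithm API, assuming RSPPO does not override it)
model.save("rsppo_cartpole")
model = RSPPO.load("rsppo_cartpole", env=env)
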
5 changes: 5 additions & 0 deletions stable_baselines3/__init__.py
@@ -8,6 +8,9 @@
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3
from stable_baselines3.rsppo import RSPPO
# from stable_baselines3.rsa2c import RSA2C
# from stable_baselines3.exptd3 import EXPTD3

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -29,6 +32,8 @@ def HER(*args, **kwargs):
"PPO",
"SAC",
"TD3",
"RSPPO",
# "RSA2C",
"HerReplayBuffer",
"get_system_info",
]
59 changes: 58 additions & 1 deletion stable_baselines3/common/buffers.py
@@ -419,7 +419,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
:param last_values: state value estimation for the last step (one for each env)
:param dones: if the last step was a terminal step (one bool for each env).
"""
# Convert to numpy
# # Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]

last_gae_lam = 0
@@ -437,6 +437,21 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values

# last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]
# values = np.concatenate((self.values, last_values.reshape(1, -1)))
# dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
# next_non_terminal = (1.0 - dones.astype(np.float32))[1:]

# # self.returns = self.rewards + self.gamma * next_non_terminal * values[1:]
# # self.advantages = self.returns - self.values

# returns = [self.values[-1]]
# interm = self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * values[1:]
# for step in reversed(range(self.buffer_size)):
# returns.append(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * returns[-1])
# self.returns = np.stack(list(reversed(returns))[:-1], 0)
# self.advantages = self.returns - self.values

def add(
self,
obs: np.ndarray,
@@ -521,6 +536,48 @@ def _get_samples(
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))


class ExpRolloutBuffer(RolloutBuffer):
    """Rollout buffer that replaces the GAE computation of ``RolloutBuffer`` with an
    exponential-utility (risk-sensitive) return recursion controlled by ``beta``."""

    def __init__(self, buffer_size, observation_space, action_space, device="auto", gae_lambda=0.95, gamma=0.99, n_envs=1, beta=0):
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
        self.beta = beta

    def compute_returns_and_advantage(self, last_values, dones):

        # # Convert to numpy
        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]

        # last_gae_lam = 0
        # for step in reversed(range(self.buffer_size)):
        #     if step == self.buffer_size - 1:
        #         next_non_terminal = 1.0 - dones.astype(np.float32)
        #         next_values = last_values
        #     else:
        #         next_non_terminal = 1.0 - self.episode_starts[step + 1]
        #         next_values = self.values[step + 1]
        #     delta = np.exp(self.beta * self.rewards[step] + self.gamma * np.log(1e-15 + np.maximum(next_values, 0)) * next_non_terminal) - self.values[step]
        #     # delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
        #     last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
        #     self.advantages[step] = last_gae_lam
        # # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        # self.returns = self.advantages + self.values

        # Convert the bootstrap values to numpy and build next-step value / non-terminal arrays
        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
        values = np.concatenate((self.values, last_values.reshape(1, -1)))
        dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
        next_non_terminal = (1.0 - dones.astype(np.float32))[1:]

        # Backward recursion of the exponential-utility return in log space,
        # seeded with the critic value of the final stored step
        returns = [self.values[-1]]
        interm = self.beta * self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * np.log(1e-15 + np.maximum(0, values[1:]))
        for step in reversed(range(self.buffer_size)):
            returns.append(np.exp(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * np.log(1e-15 + np.maximum(0, returns[-1]))))
        self.returns = np.stack(list(reversed(returns))[:-1], 0)
        self.advantages = self.returns - self.values


class DictReplayBuffer(ReplayBuffer):
"""
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
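
As a reading aid (not part of the diff), the active branch of ExpRolloutBuffer.compute_returns_and_advantage above appears to implement the backward recursion

$$
R_t = \exp\!\Big(\beta r_t + \gamma(1-\lambda)\,d_t \log\big(\varepsilon + \max(V_{t+1}, 0)\big) + \gamma\lambda\,d_t \log\big(\varepsilon + \max(R_{t+1}, 0)\big)\Big),
\qquad A_t = R_t - V_t,
$$

where $r_t$ is self.rewards, $V_t$ the value estimates (with $V_T$ taken from last_values), $d_t$ the next_non_terminal flag, and $\varepsilon = 10^{-15}$; the recursion is seeded at the end of the rollout with the critic value of the final stored step (self.values[-1]). Ignoring the $\varepsilon$ floor and terminals, setting $\lambda = 1$ telescopes this to $R_t = \exp\big(\beta \sum_{k=t}^{T-1} \gamma^{k-t} r_k\big)\, R_T^{\gamma^{T-t}}$, i.e. an exponential-utility (risk-sensitive) return, consistent with the RSPPO naming.
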
2 changes: 1 addition & 1 deletion stable_baselines3/ppo/ppo.py
@@ -217,7 +217,7 @@ def train(self) -> None:
# Normalization does not make sense if mini batchsize == 1, see GH issue #325
if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)

4 changes: 4 additions & 0 deletions stable_baselines3/rsppo/__init__.py
@@ -0,0 +1,4 @@
from stable_baselines3.rsppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from stable_baselines3.rsppo.rsppo import RSPPO

__all__ = ["RSPPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
7 changes: 7 additions & 0 deletions stable_baselines3/rsppo/policies.py
@@ -0,0 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for RSPPO
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
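
Since these are plain aliases, the string "MlpPolicy" and the policy class itself should be interchangeable when constructing the model, provided RSPPO registers the same policy_aliases mapping that PPO does (not shown in this diff); a hypothetical check:

from stable_baselines3 import RSPPO
from stable_baselines3.rsppo.policies import MlpPolicy
from stable_baselines3.common.policies import ActorCriticPolicy

assert MlpPolicy is ActorCriticPolicy  # the alias defined in this file
model_a = RSPPO("MlpPolicy", "CartPole-v1")  # resolved through the alias, assuming PPO-style policy_aliases
model_b = RSPPO(MlpPolicy, "CartPole-v1")    # passing the policy class directly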