
Commit 5b6f908

Add Risk Sensitive PPO (#1)
1 parent: 656de97

8 files changed: +420 -15 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -48,3 +48,6 @@ src
 *.prof

 MUJOCO_LOG.TXT
+
+rsa2c/
+exptd3/

example.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+import numpy as np
+import gymnasium as gym
+from stable_baselines3 import PPO
+from stable_baselines3.common.utils import set_random_seed
+
+set_random_seed(42)
+
+env = gym.make('CartPole-v1')
+model = PPO('MlpPolicy', env, verbose=1)
+
+model.learn(total_timesteps=1e6)

stable_baselines3/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,9 @@
 from stable_baselines3.ppo import PPO
 from stable_baselines3.sac import SAC
 from stable_baselines3.td3 import TD3
+from stable_baselines3.rsppo import RSPPO
+# from stable_baselines3.rsa2c import RSA2C
+# from stable_baselines3.exptd3 import EXPTD3

 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")

@@ -29,6 +32,8 @@ def HER(*args, **kwargs):
     "PPO",
     "SAC",
     "TD3",
+    "RSPPO",
+    # "RSA2C",
     "HerReplayBuffer",
     "get_system_info",
 ]
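With RSPPO exported at the package root, it should slot into example.py above as a drop-in replacement for PPO. A minimal sketch, assuming RSPPO mirrors PPO's constructor and learn() interface (the stable_baselines3/rsppo/rsppo.py module itself is not shown in this diff):

import gymnasium as gym

from stable_baselines3 import RSPPO
from stable_baselines3.common.utils import set_random_seed

set_random_seed(42)

env = gym.make("CartPole-v1")
# Assumes the usual PPO-style interface; any risk-sensitivity knob (e.g. a beta
# coefficient like the one on ExpRolloutBuffer below) would be an extra,
# hypothetical keyword argument not shown here.
model = RSPPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)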

stable_baselines3/common/buffers.py

Lines changed: 69 additions & 14 deletions
@@ -419,23 +419,35 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
         :param last_values: state value estimation for the last step (one for each env)
         :param dones: if the last step was a terminal step (one bool for each env).
         """
-        # Convert to numpy
+        # # Convert to numpy
+        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+
+        # last_gae_lam = 0
+        # for step in reversed(range(self.buffer_size)):
+        #     if step == self.buffer_size - 1:
+        #         next_non_terminal = 1.0 - dones.astype(np.float32)
+        #         next_values = last_values
+        #     else:
+        #         next_non_terminal = 1.0 - self.episode_starts[step + 1]
+        #         next_values = self.values[step + 1]
+        #     delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+        #     last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+        #     self.advantages[step] = last_gae_lam
+        # # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+        # # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+        # self.returns = self.advantages + self.values
+
         last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+        values = np.concatenate((self.values, last_values.reshape(1, -1)))
+        dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
+        next_non_terminal = (1.0 - dones.astype(np.float32))[1:]

-        last_gae_lam = 0
+        returns = [self.values[-1]]
+        interm = self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * values[1:]
         for step in reversed(range(self.buffer_size)):
-            if step == self.buffer_size - 1:
-                next_non_terminal = 1.0 - dones.astype(np.float32)
-                next_values = last_values
-            else:
-                next_non_terminal = 1.0 - self.episode_starts[step + 1]
-                next_values = self.values[step + 1]
-            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
-            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
-            self.advantages[step] = last_gae_lam
-        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
-        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
-        self.returns = self.advantages + self.values
+            returns.append(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * returns[-1])
+        self.returns = np.stack(list(reversed(returns))[:-1], 0)
+        self.advantages = self.returns - self.values

     def add(
         self,
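The rewritten compute_returns_and_advantage builds the TD(λ) return directly through the recursion R_t = r_t + γ·(1 − d_{t+1})·[(1 − λ)·V(s_{t+1}) + λ·R_{t+1}] and then takes advantages = returns − values, instead of accumulating GAE deltas as before; the two formulations agree up to the choice of bootstrap (the diff seeds the recursion with self.values[-1], i.e. the value of the last stored state). A standalone, single-environment NumPy sketch of the recursion, with names local to the sketch rather than the stable-baselines3 API, bootstrapping the final step with V(s_T) = last_value:

import numpy as np

def lambda_returns(rewards, values, last_value, episode_starts, last_done,
                   gamma=0.99, gae_lambda=0.95):
    """TD(lambda) returns for a single environment; array arguments have shape (T,)."""
    values_ext = np.append(values, last_value)      # V(s_0), ..., V(s_T)
    dones = np.append(episode_starts, last_done)
    next_non_terminal = 1.0 - dones[1:]             # zero out the bootstrap at episode ends
    ret = values_ext[-1]
    returns = np.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        # R_t = r_t + gamma * (1 - d_{t+1}) * [(1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}]
        ret = rewards[t] + gamma * next_non_terminal[t] * (
            (1 - gae_lambda) * values_ext[t + 1] + gae_lambda * ret
        )
        returns[t] = ret
    return returns, returns - values                # returns and advantages

For gae_lambda = 1 this reduces to the discounted Monte-Carlo return bootstrapped with V(s_T); for gae_lambda = 0 it is the one-step TD target.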
@@ -521,6 +533,49 @@ def _get_samples(
         return RolloutBufferSamples(*tuple(map(self.to_torch, data)))


+class ExpRolloutBuffer(RolloutBuffer):
+
+    def __init__(self, buffer_size, observation_space, action_space, device = "auto", gae_lambda = 0.95, gamma = 0.99, n_envs = 1, beta = 0):
+        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
+        self.beta = beta
+
+    def compute_returns_and_advantage(self, last_values, dones):
+
+        # # Convert to numpy
+        # last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+
+        # last_gae_lam = 0
+        # for step in reversed(range(self.buffer_size)):
+        #     if step == self.buffer_size - 1:
+        #         next_non_terminal = 1.0 - dones.astype(np.float32)
+        #         next_values = last_values
+        #     else:
+        #         next_non_terminal = 1.0 - self.episode_starts[step + 1]
+        #         next_values = self.values[step + 1]
+        #     delta = np.exp(self.beta * self.rewards[step] + self.gamma * np.log(1e-15 + np.maximum(next_values, 0)) * next_non_terminal) - self.values[step]
+        #     # delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+        #     last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+        #     self.advantages[step] = last_gae_lam
+        # # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+        # # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+        # self.returns = self.advantages + self.values
+
+
+        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]
+        values = np.concatenate((self.values, last_values.reshape(1, -1)))
+        dones = np.concatenate((self.episode_starts, dones.reshape(1, -1)))
+        next_non_terminal = (1.0 - dones.astype(np.float32))[1:]
+
+        returns = [self.values[-1]]
+        interm = self.beta * self.rewards + self.gamma * (1 - self.gae_lambda) * next_non_terminal * np.log(1e-15 + np.maximum(0, values[1:]))
+        for step in reversed(range(self.buffer_size)):
+            returns.append(np.exp(interm[step] + self.gamma * self.gae_lambda * next_non_terminal[step] * np.log(1e-15 + np.maximum(0, returns[-1]))))
+        self.returns = np.stack(list(reversed(returns))[:-1], 0)
+        self.advantages = self.returns - self.values
+
+
+
+
 class DictReplayBuffer(ReplayBuffer):
     """
     Dict Replay buffer used in off-policy algorithms like SAC/TD3.
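ExpRolloutBuffer replaces the additive λ-return with a multiplicative, exponential-utility one: the recursion runs in log space and targets R_t = exp(β·r_t + γ·(1 − d_{t+1})·[(1 − λ)·log V(s_{t+1}) + λ·log R_{t+1}]), flooring values at zero and adding 1e-15 inside the log for numerical safety, so the critic is treated as an exponentiated (utility-space) value rather than an expected return. A single-environment sketch under the same conventions as the previous one (illustrative only, not library API):

import numpy as np

def exp_lambda_returns(rewards, values, last_value, episode_starts, last_done,
                       gamma=0.99, gae_lambda=0.95, beta=1.0, eps=1e-15):
    """Risk-sensitive (exponential-utility) lambda returns for a single environment."""
    values_ext = np.append(values, last_value)
    dones = np.append(episode_starts, last_done)
    next_non_terminal = 1.0 - dones[1:]
    ret = values_ext[-1]
    returns = np.zeros(len(rewards))
    for t in reversed(range(len(rewards))):
        log_next = np.log(eps + max(values_ext[t + 1], 0.0))   # log V(s_{t+1}), clipped at zero
        log_boot = np.log(eps + max(ret, 0.0))                 # log R_{t+1}, clipped at zero
        # R_t = exp(beta * r_t + gamma * (1 - d_{t+1}) * [(1 - lambda) * log V + lambda * log R])
        ret = np.exp(beta * rewards[t] + gamma * next_non_terminal[t] * (
            (1 - gae_lambda) * log_next + gae_lambda * log_boot
        ))
        returns[t] = ret
    return returns, returns - values                           # returns and advantages

At a terminal step (d_{t+1} = 1) the target collapses to exp(β·r_t); β controls how strongly large or small returns are weighted, in the usual exponential-utility sense.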

stable_baselines3/ppo/ppo.py

Lines changed: 1 addition & 1 deletion (whitespace-only change)

@@ -217,7 +217,7 @@ def train(self) -> None:
                 # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                 if self.normalize_advantage and len(advantages) > 1:
                     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
-
+
                 # ratio between old and new policy, should be one at the first iteration
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)

stable_baselines3/rsppo/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from stable_baselines3.rsppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from stable_baselines3.rsppo.rsppo import RSPPO
+
+__all__ = ["PPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
stable_baselines3/rsppo/policies.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# This file is here just to define MlpPolicy/CnnPolicy
+# that work for PPO
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
+
+MlpPolicy = ActorCriticPolicy
+CnnPolicy = ActorCriticCnnPolicy
+MultiInputPolicy = MultiInputActorCriticPolicy
