This repository was archived by the owner on Jan 20, 2026. It is now read-only.

Commit f6a7e98: Soft Actor Critic (SAC) Model (#627)
1 parent: 3f6b122

File tree: 11 files changed, +761 −5 lines

CHANGELOG.md (3 additions, 4 deletions)

@@ -8,18 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

-- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))
+- Added Soft Actor Critic (SAC) Model ([#627](https://github.com/PyTorchLightning/lightning-bolts/pull/627))

-- Added Advantage Actor-Critic (A2C) Model ([#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))
+- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676))

+- Added Advantage Actor-Critic (A2C) Model ([#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))

 - Added Torch ORT Callback ([#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))
-
 - Added SparseML Callback ([#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))
-
 ### Changed

 - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
docs/source/_images/rl_benchmark/pendulum_sac_results.jpg (binary image, 55.5 KB)

docs/source/reinforce_learn.rst (76 additions, 0 deletions)

@@ -764,3 +764,79 @@ Example::

 .. autoclass:: pl_bolts.models.rl.AdvantageActorCritic
     :noindex:

--------------

Soft Actor Critic (SAC)
^^^^^^^^^^^^^^^^^^^^^^^

Soft Actor Critic model introduced in `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor <https://arxiv.org/abs/1801.01290>`__

Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine

Original implementation by: `Jason Wang <https://github.com/blahBlahhhJ>`_
Soft Actor Critic (SAC) is a powerful actor-critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a
special continuous distribution for actions (a tanh-squashed Gaussian), and its critic estimates the Q value instead of the
state value, so it takes in not only states but also actions. The new actor allows SAC to support continuous-action tasks
such as controlling robots, and the new critic allows SAC to support off-policy learning, which is more sample efficient.

The actor has a new objective: maximize entropy to encourage exploration while still maximizing the expected reward.
The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the
two as the predicted Q value.
Since SAC is off-policy, its training step is quite similar to DQN's:

1. Initialize one policy network, two Q networks, and two corresponding target Q networks.
2. Run 1 step using an action sampled from the policy and store the transition in the replay buffer.

.. math::

    a \sim \tanh(N(\mu_\pi(s), \sigma_\pi(s)))

3. Sample transitions (states, actions, rewards, dones, next states) from the replay buffer.

.. math::

    s, a, r, d, s' \sim B

4. Compute the actor loss and update the policy network.

.. math::

    J_\pi = \frac1n\sum_i(\log\pi(\pi(a | s_i) | s_i) - Q_{min}(s_i, \pi(a | s_i)))

5. Compute the Q target.

.. math::

    target_i = r_i + (1 - d_i) \gamma (\min_{j=1,2} Q_{target,j}(s'_i, \pi(a' | s'_i)) - \log\pi(\pi(a' | s'_i) | s'_i))

6. Compute the critic losses and update the Q networks.

.. math::

    J_{Q_j} = \frac1n \sum_i(Q_j(s_i, a_i) - target_i)^2

7. Soft update the target Q networks using a weighted sum of themselves and the Q networks.

.. math::

    Q_{target,j} := \tau Q_{target,j} + (1-\tau) Q_j
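As a toy numerical illustration of the clipped double-Q target and the soft update (made-up values in plain NumPy, not the library's API):

```python
import numpy as np

# Hypothetical batch of 2 transitions (illustrative values only).
gamma, tau = 0.99, 0.995

r = np.array([1.0, 0.0])            # rewards
d = np.array([0.0, 1.0])            # done flags
q1_next = np.array([5.0, 3.0])      # target critic 1 at (s', a')
q2_next = np.array([4.0, 6.0])      # target critic 2 at (s', a')
logp_next = np.array([-1.0, -2.0])  # log pi(a' | s')

# Q target: r_i + (1 - d_i) * gamma * (min_j Q_target_j - log pi)
target = r + (1.0 - d) * gamma * (np.minimum(q1_next, q2_next) - logp_next)
# target[0] = 1 + 0.99 * (4 - (-1)) = 5.95; target[1] = 0 (episode done)

# Soft (Polyak) update of one target parameter toward the online one.
w_target, w_online = 0.0, 1.0
w_target = tau * w_target + (1.0 - tau) * w_online
# w_target moves only a fraction (1 - tau) toward the online network
```

Taking the minimum of the two next-state Q estimates is what counteracts the overestimation ("positive bias") the text describes.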
SAC Benefits
~~~~~~~~~~~~

- More sample efficient due to off-policy training
- Supports continuous action spaces

SAC Results
~~~~~~~~~~~

.. image:: _images/rl_benchmark/pendulum_sac_results.jpg
    :width: 300
    :alt: SAC Results

Example::

    from pl_bolts.models.rl import SAC

    sac = SAC("Pendulum-v0")
    trainer = Trainer()
    trainer.fit(sac)

.. autoclass:: pl_bolts.models.rl.SAC
    :noindex:

pl_bolts/models/rl/__init__.py (2 additions, 0 deletions)

@@ -5,6 +5,7 @@
 from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
 from pl_bolts.models.rl.per_dqn_model import PERDQN
 from pl_bolts.models.rl.reinforce_model import Reinforce
+from pl_bolts.models.rl.sac_model import SAC
 from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient

 __all__ = [
@@ -15,5 +16,6 @@
     "NoisyDQN",
     "PERDQN",
     "Reinforce",
+    "SAC",
     "VanillaPolicyGradient",
 ]

pl_bolts/models/rl/common/agents.py (45 additions, 0 deletions)

@@ -161,3 +161,48 @@ def __call__(self, states: Tensor, device: str) -> List[int]:
        actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

        return actions


class SoftActorCriticAgent(Agent):
    """Actor-Critic based agent that returns a continuous action based on the policy."""

    def __call__(self, states: Tensor, device: str) -> List[float]:
        """Takes in the current state and returns the action based on the agent's policy.

        Args:
            states: current state of the environment
            device: the device used for the current batch

        Returns:
            action defined by policy
        """
        if not isinstance(states, list):
            states = [states]

        if not isinstance(states, Tensor):
            states = torch.tensor(states, device=device)

        dist = self.net(states)
        actions = [a for a in dist.sample().cpu().numpy()]

        return actions

    def get_action(self, states: Tensor, device: str) -> List[float]:
        """Get the action greedily (without sampling).

        Args:
            states: current state of the environment
            device: the device used for the current batch

        Returns:
            action defined by policy
        """
        if not isinstance(states, list):
            states = [states]

        if not isinstance(states, Tensor):
            states = torch.tensor(states, device=device)

        actions = [self.net.get_action(states).cpu().numpy()]

        return actions
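The two code paths above (stochastic `__call__` for rollouts, greedy `get_action` for evaluation) can be illustrated with a stub tanh-Gaussian policy (a pure-Python stand-in, not the library's classes):

```python
import math
import random

class StubTanhPolicy:
    """Stand-in for a 1-D tanh-Gaussian policy head with fixed mean 0.5, std 0.1."""

    mean, std = 0.5, 0.1

    def sample(self):
        # Rollout path: sample Z ~ N(mean, std), then squash with tanh.
        z = random.gauss(self.mean, self.std)
        return math.tanh(z)

    def greedy(self):
        # Evaluation path: deterministic tanh of the mean, no sampling.
        return math.tanh(self.mean)

policy = StubTanhPolicy()
explore = [policy.sample() for _ in range(5)]  # varies from run to run
exploit = policy.greedy()                      # always tanh(0.5), never random
```

Either way the tanh squash guarantees every emitted action lies strictly inside (-1, 1).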
pl_bolts/models/rl/common/distributions.py (new file, 62 additions, 0 deletions)

@@ -0,0 +1,62 @@
"""Distributions used in some continuous RL algorithms."""
import torch


class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
    """The distribution of X is an affine transformation of tanh applied to a normal distribution.

    X = action_scale * tanh(Z) + action_bias
    Z ~ Normal(mean, variance)
    """

    def __init__(self, action_bias, action_scale, **kwargs):
        super().__init__(**kwargs)

        self.action_bias = action_bias
        self.action_scale = action_scale

    def rsample_with_z(self, sample_shape=torch.Size()):
        """Samples X using the reparametrization trick, also returning the intermediate variable Z.

        Returns:
            Sampled X and Z
        """
        z = super().rsample()
        return self.action_scale * torch.tanh(z) + self.action_bias, z

    def log_prob_with_z(self, value, z):
        """Computes the log probability of a sampled X.

        Refer to equations (20) and (21) in the original SAC paper for more details.

        Args:
            value: the value of X
            z: the value of Z
        Returns:
            Log probability of the sample
        """
        value = (value - self.action_bias) / self.action_scale
        z_logprob = super().log_prob(z)
        correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
        return z_logprob - correction

    def rsample_and_log_prob(self, sample_shape=torch.Size()):
        """Samples X and computes the log probability of the sample.

        Returns:
            Sampled X and log probability
        """
        z = super().rsample()
        z_logprob = super().log_prob(z)
        value = torch.tanh(z)
        correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
        return self.action_scale * value + self.action_bias, z_logprob - correction

    def rsample(self, sample_shape=torch.Size()):
        fz, z = self.rsample_with_z(sample_shape)
        return fz

    def log_prob(self, value):
        normalized = (value - self.action_bias) / self.action_scale
        # atanh(normalized) recovers Z from X
        z = torch.log(1 + normalized) / 2 - torch.log(1 - normalized) / 2
        # pass the original value: log_prob_with_z rescales it internally
        return self.log_prob_with_z(value, z)
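A one-dimensional NumPy sketch of the same change of variables (assuming `action_scale=1` and `action_bias=0`; this is an illustration, not the class above) shows that the corrected log-probability really is a density on (-1, 1):

```python
import numpy as np

def normal_logpdf(z, mu, sigma):
    """Log-density of a scalar Gaussian N(mu, sigma^2)."""
    return -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

def tanh_normal_logprob(x, mu, sigma):
    # X = tanh(Z), Z ~ N(mu, sigma^2), so by change of variables:
    # log p_X(x) = log p_Z(atanh(x)) - log(1 - x^2)
    z = np.arctanh(x)
    return normal_logpdf(z, mu, sigma) - np.log(1.0 - x ** 2)

# Sanity check: the density should integrate to ~1 over the squashed support (-1, 1).
xs = np.linspace(-0.999, 0.999, 20001)
pdf = np.exp(tanh_normal_logprob(xs, 0.3, 0.5))
dx = xs[1] - xs[0]
mass = float(((pdf[:-1] + pdf[1:]) / 2.0 * dx).sum())  # trapezoidal rule
# mass comes out close to 1
```

The `- log(1 - x^2)` term is the Jacobian correction the class applies (with an `1e-7` epsilon there for numerical safety near the boundary).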

pl_bolts/models/rl/common/networks.py (61 additions, 1 deletion)

@@ -4,10 +4,12 @@

 import numpy as np
 import torch
-from torch import Tensor, nn
+from torch import FloatTensor, Tensor, nn
 from torch.distributions import Categorical, Normal
 from torch.nn import functional as F

+from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal
+

 class CNN(nn.Module):
     """Simple MLP network."""

@@ -84,6 +86,64 @@ def forward(self, input_x):
        return self.net(input_x.float())


class ContinuousMLP(nn.Module):
    """MLP network that outputs a continuous value via a Gaussian distribution."""

    def __init__(
        self,
        input_shape: Tuple[int],
        n_actions: int,
        hidden_size: int = 128,
        action_bias: int = 0,
        action_scale: int = 1,
    ):
        """
        Args:
            input_shape: observation shape of the environment
            n_actions: dimension of actions in the environment
            hidden_size: size of hidden layers
            action_bias: the center of the action space
            action_scale: the scale of the action space
        """
        super().__init__()
        self.action_bias = action_bias
        self.action_scale = action_scale

        self.shared_net = nn.Sequential(
            nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU()
        )
        self.mean_layer = nn.Linear(hidden_size, n_actions)
        self.logstd_layer = nn.Linear(hidden_size, n_actions)

    def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
        """Forward pass through the network. Calculates the action distribution.

        Args:
            x: input to network
        Returns:
            action distribution
        """
        x = self.shared_net(x.float())
        batch_mean = self.mean_layer(x)
        logstd = torch.clamp(self.logstd_layer(x), -20, 2)
        batch_scale_tril = torch.diag_embed(torch.exp(logstd))
        return TanhMultivariateNormal(
            action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril
        )

    def get_action(self, x: FloatTensor) -> Tensor:
        """Get the action greedily (without sampling).

        Args:
            x: input to network
        Returns:
            mean action
        """
        x = self.shared_net(x.float())
        batch_mean = self.mean_layer(x)
        return self.action_scale * torch.tanh(batch_mean) + self.action_bias


class ActorCriticMLP(nn.Module):
    """MLP network with heads for actor and critic."""
