
Commit 30a1984

AlexPasqua and araffin authored

Deprecation of shared layers in MlpExtractor (#1252)

* Deprecation warning for shared layers in MlpExtractor
* Updated changelog
* Updated custom policy doc
* Update doc and deprecation
* Fix doc build
* Minor edits

Co-authored-by: Antonin Raffin <[email protected]>

1 parent 4fa17dc commit 30a1984

File tree

8 files changed: +153 −107 lines


docs/guide/custom_policy.rst

Lines changed: 46 additions & 64 deletions
@@ -51,8 +51,6 @@ Each of these network have a features extractor followed by a fully-connected ne
 .. image:: ../_static/img/sb3_policy.png
 
 
-.. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
-
 
 Custom Network Architecture
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -90,13 +88,13 @@ using ``policy_kwargs`` parameter:
 # of two layers of size 32 each with Relu activation function
 # Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
 policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                     net_arch=[dict(pi=[32, 32], vf=[32, 32])])
+                     net_arch=dict(pi=[32, 32], vf=[32, 32]))
 # Create the agent
 model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
 # Retrieve the environment
 env = model.get_env()
 # Train the agent
-model.learn(total_timesteps=100000)
+model.learn(total_timesteps=20_000)
 # Save the agent
 model.save("ppo_cartpole")
 
@@ -114,13 +112,14 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
 
 .. note::
 
-    By default the features extractor is shared between the actor and the critic to save computation (when applicable).
+    For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable).
     However, this can be changed setting ``share_features_extractor=False`` in the
     ``policy_kwargs`` (both for on-policy and off-policy algorithms).
 
 
 .. warning::
     If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
+    Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared.
 
 
 .. code-block:: python
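
For reference, the ``share_features_extractor`` option mentioned in the note above is passed through ``policy_kwargs``. A minimal sketch (not part of this diff; the algorithm, environment and budget are illustrative only):

    from stable_baselines3 import PPO

    # Use separate features extractors for the actor and the critic.
    # With a non-shared features extractor, shared layers in the
    # mlp_extractor are not allowed (and are deprecated anyway).
    policy_kwargs = dict(share_features_extractor=False)

    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=1_000)
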
@@ -240,64 +239,56 @@ downsampling and "vector" with a single linear layer.
 On-Policy Algorithms
 ^^^^^^^^^^^^^^^^^^^^
 
-Shared Networks
+Custom Networks
 ---------------
 
-The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many
-of them are shared between the policy network and the value network. It is assumed to be a list with the following
-structure:
-
-1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
-   If the number of ints is zero, there will be no shared layers.
-2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
-   It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
-   If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed.
-
-In short: ``[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]``.
-
-Examples
-~~~~~~~~
-
-Two shared layers of size 128: ``net_arch=[128, 128]``
-
-
-.. code-block:: none
+.. warning::
+    Shared layers in the ``mlp_extractor`` are **deprecated**.
+    In a future release all layers will have to be non-shared.
+    If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).
 
-        obs
-         |
-       <128>
-         |
-       <128>
-        / \
-   action  value
+.. warning::
+    In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
+    to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently)
+    for the actor and the critic, with the same architecture.
 
 
-Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]``
+If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
+you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.
 
-.. code-block:: none
+For example, if you want a different architecture for the actor (aka ``pi``) and the critic (value-function aka ``vf``) networks,
+then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.
 
-        obs
-         |
-       <128>
-        / \
-   action  <256>
-             |
-           <256>
-             |
-           value
+.. Otherwise, to have actor and critic that share the same network architecture,
+.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).
 
+Examples
+~~~~~~~~
 
-Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
+.. TODO(antonin): uncomment when shared network is removed
+.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
+..
+.. .. code-block:: none
+..
+..        obs
+..       /   \
+..    <128>  <128>
+..      |      |
+..    <128>  <128>
+..      |      |
+..   action  value
+
+Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``
 
 .. code-block:: none
 
-        obs
-         |
-       <128>
-        / \
-    <16>   <256>
-      |      |
-   action  value
+        obs
+       /   \
+    <32>   <64>
+      |      |
+    <32>   <64>
+      |      |
+   action  value
 
 
 Advanced Example
@@ -334,7 +325,7 @@ If your task requires even more granular control over the policy/value architect
         last_layer_dim_pi: int = 64,
         last_layer_dim_vf: int = 64,
     ):
-        super(CustomNetwork, self).__init__()
+        super().__init__()
 
         # IMPORTANT:
         # Save output dimensions, used to create the distributions
@@ -370,8 +361,6 @@ If your task requires even more granular control over the policy/value architect
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Callable[[float], float],
-        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
         *args,
         **kwargs,
     ):
@@ -380,8 +369,6 @@ If your task requires even more granular control over the policy/value architect
             observation_space,
             action_space,
             lr_schedule,
-            net_arch,
-            activation_fn,
             # Pass remaining arguments to base class
             *args,
             **kwargs,
@@ -402,21 +389,16 @@ If your task requires even more granular control over the policy/value architect
 Off-Policy Algorithms
 ^^^^^^^^^^^^^^^^^^^^^
 
-If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
-you can pass a dictionary of the following structure: ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
+If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG``, ``TQC`` or ``TD3``,
+you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], qf=[<critic network architecture>])``.
 
 For example, if you want a different architecture for the actor (aka ``pi``) and the critic (Q-function aka ``qf``) networks,
-then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
+then you can specify ``net_arch=dict(pi=[64, 64], qf=[400, 300])``.
 
 Otherwise, to have actor and critic that share the same network architecture,
 you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
 
 
-.. note::
-    Compared to their on-policy counterparts, no shared layers (other than the features extractor)
-    between the actor and the critic are allowed (to prevent issues with target networks).
-
-
 .. note::
     For advanced customization of off-policy algorithms policies, please take a look at the code.
     A good understanding of the algorithm used is required, see discussion in `issue #425 <https://github.com/DLR-RM/stable-baselines3/issues/425>`_
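
Taken together, the updated documentation corresponds to the following usage pattern. This is a minimal sketch based on the text above, not part of the diff; the layer sizes and environments are illustrative only:

    from stable_baselines3 import PPO, SAC

    # On-policy (PPO/A2C/TRPO): separate actor (pi) and critic (vf) architectures
    ppo_model = PPO("MlpPolicy", "CartPole-v1",
                    policy_kwargs=dict(net_arch=dict(pi=[32, 32], vf=[64, 64])))

    # Off-policy (SAC/DDPG/TD3/TQC): actor (pi) and Q-function critic (qf)
    sac_model = SAC("MlpPolicy", "Pendulum-v1",
                    policy_kwargs=dict(net_arch=dict(pi=[64, 64], qf=[400, 300])))

    # Same architecture for actor and critic: a plain list of layer sizes
    shared_model = SAC("MlpPolicy", "Pendulum-v1",
                       policy_kwargs=dict(net_arch=[256, 256]))
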

docs/misc/changelog.rst

Lines changed: 34 additions & 14 deletions
@@ -4,9 +4,16 @@ Changelog
 ==========
 
 
-Release 1.7.0a11 (WIP)
+Release 1.7.0a12 (WIP)
 --------------------------
 
+.. warning::
+
+    Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO.
+    This feature will be removed in SB3 v1.8.0 and the behavior of ``net_arch=[64, 64]``
+    will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms.
+
+
 .. note::
 
     A2C and PPO saved with SB3 < 1.7.0 will show a warning about
@@ -34,8 +41,15 @@ New Features:
 - Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor``
 - Added support for Python 3.10
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+- Fixed a bug in ``RecurrentPPO`` where the lstm states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
+- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``
+
+`RL Zoo`_
+^^^^^^^^^
+- Added support for python file for configuration
+- Added ``monitor_kwargs`` parameter
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -52,6 +66,7 @@ Bug Fixes:
 Deprecations:
 ^^^^^^^^^^^^^
 - You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()``
+- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua)
 
 Others:
 ^^^^^^^
@@ -99,8 +114,12 @@ New Features:
 - Added progress bar callback
 - The `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ can now be installed as a package (``pip install rl_zoo3``)
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+
+`RL Zoo`_
+^^^^^^^^^
+- RL Zoo is now a python package and can be installed using ``pip install rl_zoo3``
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -135,8 +154,8 @@ New Features:
 - Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
 - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
 - Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel)
 
 Bug Fixes:
@@ -192,8 +211,8 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
 - Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
 
 
@@ -246,8 +265,8 @@ New Features:
   depending on desired maximum width of output.
 - Allow PPO to turn of advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
 - coming soon: Cross Entropy Method, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/62
 
 Bug Fixes:
@@ -309,8 +328,8 @@ New Features:
 - Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong
 - Added ``copy()`` and ``combine()`` methods to ``RunningMeanStd``
 
-SB3-Contrib
-^^^^^^^^^^^
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
 - Added Trust Region Policy Optimization (TRPO) (@cyprienc)
 - Added Augmented Random Search (ARS) (@sgillen)
 - Coming soon: PPO LSTM, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
@@ -1137,7 +1156,8 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).
 .. _Quentin Gallouédec: https://gallouedec.com/
 .. _@qgallouedec: https://github.com/qgallouedec
 
-
+.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
+.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
 
 Contributors:
 -------------
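
For users updating existing configurations, the deprecation described in this changelog entry amounts to the following before/after change (a sketch for illustration, not part of the diff):

    from stable_baselines3 import PPO

    # Deprecated (to be removed in SB3 v1.8.0): a list wrapping the dict,
    # optionally preceded by shared layer sizes
    old_kwargs = dict(net_arch=[dict(pi=[64, 64], vf=[64, 64])])

    # Recommended: pass the dict directly (no shared mlp_extractor layers)
    new_kwargs = dict(net_arch=dict(pi=[64, 64], vf=[64, 64]))

    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=new_kwargs)
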

stable_baselines3/common/base_class.py

Lines changed: 1 addition & 1 deletion
@@ -617,7 +617,7 @@ def set_parameters(
                 f"expected {objects_needing_update}, got {updated_objects}"
             )
 
-    @classmethod
+    @classmethod  # noqa: C901
     def load(
         cls: Type[SelfBaseAlgorithm],
         path: Union[str, pathlib.Path, io.BufferedIOBase],

stable_baselines3/common/envs/identity_env.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import numpy as np
 from gym import spaces
 
-from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
+from stable_baselines3.common.type_aliases import GymStepReturn
 
 T = TypeVar("T", int, np.ndarray)
 

stable_baselines3/common/policies.py

Lines changed: 23 additions & 5 deletions
@@ -418,7 +418,8 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+        # TODO(antonin): update type annotation when we remove shared network support
+        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
@@ -451,12 +452,28 @@ def __init__(
             normalize_images=normalize_images,
         )
 
+        # Convert [dict()] to dict() as shared network are deprecated
+        if isinstance(net_arch, list) and len(net_arch) > 0:
+            if isinstance(net_arch[0], dict):
+                warnings.warn(
+                    (
+                        "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
+                        "you should now pass directly a dictionary and not a list "
+                        "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
+                    ),
+                )
+                net_arch = net_arch[0]
+            else:
+                # Note: deprecation warning will be emitted
+                # by the MlpExtractor constructor
+                pass
+
         # Default network architecture, from stable-baselines
         if net_arch is None:
             if features_extractor_class == NatureCNN:
                 net_arch = []
             else:
-                net_arch = [dict(pi=[64, 64], vf=[64, 64])]
+                net_arch = dict(pi=[64, 64], vf=[64, 64])
 
         self.net_arch = net_arch
         self.activation_fn = activation_fn
@@ -472,7 +489,8 @@ def __init__(
             self.pi_features_extractor = self.features_extractor
             self.vf_features_extractor = self.make_features_extractor()
             # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
-            if len(net_arch) > 0 and not isinstance(net_arch[0], dict):
+            # TODO(antonin): update the check once we change net_arch behavior
+            if isinstance(net_arch, list) and len(net_arch) > 0:
                 raise ValueError(
                     "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
                 )
@@ -752,7 +770,7 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
@@ -825,7 +843,7 @@ def __init__(
         observation_space: spaces.Dict,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
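
To see the added conversion logic in isolation, here is a standalone rephrasing of the check above; ``normalize_net_arch`` is a hypothetical helper written for illustration, not a function in the library:

    import warnings
    from typing import Dict, List, Union

    def normalize_net_arch(
        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None]
    ) -> Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None]:
        # Mirrors the check added in ActorCriticPolicy.__init__:
        # a list whose first element is a dict is unwrapped with a deprecation message.
        if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
            warnings.warn(
                "Shared layers in the mlp_extractor are deprecated, "
                "pass net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)]"
            )
            return net_arch[0]
        # Plain lists of ints (shared layers) are left as-is;
        # MlpExtractor emits its own deprecation warning for those.
        return net_arch

    print(normalize_net_arch([dict(pi=[64, 64], vf=[64, 64])]))  # -> {'pi': [64, 64], 'vf': [64, 64]}
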

0 commit comments
