
Commit b702884

AlexPasqua and araffin authored
Removed shared layers in mlp_extractor (#1292)
* Modified actor-critic policies & MlpExtractor class

  ActorCriticPolicy:
  - changed type hint of net_arch param: now it's a dict
  - removed check that, if the features extractor is not shared, no shared layers are allowed in the mlp_extractor, regardless of the features extractor

  ActorCriticCnnPolicy:
  - changed type hint of net_arch param: now it's a dict

  MultiInputActorCriticPolicy:
  - changed type hint of net_arch param: now it's a dict

  MlpExtractor:
  - changed type hint of net_arch param: now it's a dict
  - adapted networks creation
  - adapted methods: forward, forward_actor & forward_critic

* Removed shared layers in mlp_extractor
* Updated docs and changelog + reformat
* Updated custom policy tests
* Removed test on deprecation warning for shared layers in mlp_extractor (shared layers are now removed)
* Update version
* Update RL Zoo doc
* Fix linter warnings
* Add ruff to Makefile (experimental)
* Add backward compat code and minor updates
* Update tests
* Add backward compatibility
* Fix test
* Improve compat code

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 69fdf15 commit b702884

File tree

11 files changed: +123 additions, -163 deletions


Makefile

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,13 @@ lint:
 	# exit-zero treats all errors as warnings.
 	flake8 ${LINT_PATHS} --count --exit-zero --statistics
 
+ruff:
+	# stop the build if there are Python syntax errors or undefined names
+	# see https://lintlyci.github.io/Flake8Rules/
+	ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
+	# exit-zero treats all errors as warnings.
+	ruff ${LINT_PATHS} --exit-zero --line-length 127
+
 format:
 	# Sort imports
 	isort ${LINT_PATHS}

docs/guide/custom_policy.rst

Lines changed: 15 additions & 30 deletions
@@ -117,11 +117,6 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
 ``policy_kwargs`` (both for on-policy and off-policy algorithms).
 
 
-.. warning::
-    If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
-    Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared.
-
-
 .. code-block:: python
 
     import torch as th
@@ -242,41 +237,31 @@ On-Policy Algorithms
 Custom Networks
 ---------------
 
-.. warning::
-    Shared layers in the the ``mlp_extractor`` are **deprecated**.
-    In a future release all layers will have to be non-shared.
-    If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).
-
-.. warning::
-    In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
-    to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently)
-    for the actor and the critic, with the same architecture.
-
-
 If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
 you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.
 
 For example, if you want a different architecture for the actor (aka ``pi``) and the critic (value-function aka ``vf``) networks,
 then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.
 
-.. Otherwise, to have actor and critic that share the same network architecture,
-.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).
+Otherwise, to have actor and critic that share the same network architecture,
+you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each, this is equivalent to ``net_arch=dict(pi=[128, 128], vf=[128, 128])``).
+
+If shared layers are needed, you need to implement a custom policy network (see `advanced example below <#advanced-example>`_).
 
 Examples
 ~~~~~~~~
 
-.. TODO(antonin): uncomment when shared network is removed
-.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
-..
-.. .. code-block:: none
-..
-..        obs
-..        / \
-..      <128> <128>
-..        |     |
-..      <128> <128>
-..        |     |
-..    action   value
+Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
+
+.. code-block:: none
+
+        obs
+        / \
+      <128> <128>
+        |     |
+      <128> <128>
+        |     |
+    action   value
 
 Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``
 
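
For reference, the new ``net_arch`` formats documented above are passed through ``policy_kwargs``. A minimal sketch, assuming stable-baselines3 >= 1.8.0 and the standard CartPole-v1 environment (not part of this diff):

    from stable_baselines3 import PPO

    # Different architectures for the actor (pi) and the critic (vf)
    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=dict(pi=[32, 32], vf=[64, 64])),
    )

    # Same (but separate, non-shared) architecture for both networks;
    # shorthand for net_arch=dict(pi=[128, 128], vf=[128, 128])
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=[128, 128]))
    model.learn(total_timesteps=1_000)
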
docs/guide/rl_zoo.rst

Lines changed: 13 additions & 6 deletions
@@ -20,6 +20,10 @@ Goals of this repository:
 Installation
 ------------
 
+Option 1: install the python package ``pip install rl_zoo3``
+
+or:
+
 1. Clone the repository:
 
 ::
@@ -42,7 +46,10 @@ Installation
 ::
 
     apt-get install swig cmake ffmpeg
+    # full dependencies
     pip install -r requirements.txt
+    # minimal dependencies
+    pip install -e .
 
 
 Train an Agent
@@ -56,21 +63,21 @@ using:
 
 ::
 
-    python train.py --algo algo_name --env env_id
+    python -m rl_zoo3.train --algo algo_name --env env_id
 
 For example (with evaluation and checkpoints):
 
 ::
 
-    python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
+    python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
 
 
 Continue training (here, load pretrained agent for Breakout and continue
 training for 5000 steps):
 
 ::
 
-    python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
+    python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
 
 
 Enjoy a Trained Agent
@@ -80,13 +87,13 @@ If the trained agent exists, then you can see it in action using:
 
 ::
 
-    python enjoy.py --algo algo_name --env env_id
+    python -m rl_zoo3.enjoy --algo algo_name --env env_id
 
 For example, enjoy A2C on Breakout during 5000 timesteps:
 
 ::
 
-    python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
+    python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
 
 
 Hyperparameter Optimization
@@ -100,7 +107,7 @@ with a budget of 1000 trials and a maximum of 50000 steps:
 
 ::
 
-    python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
+    python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
       --sampler random --pruner median
 
 
docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -4,12 +4,13 @@ Changelog
 ==========
 
 
-Release 1.8.0a1 (WIP)
+Release 1.8.0a2 (WIP)
 --------------------------
 
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
 
 New Features:
 ^^^^^^^^^^^^^

stable_baselines3/common/base_class.py

Lines changed: 5 additions & 1 deletion
@@ -667,6 +667,11 @@ def load(
         if "policy_kwargs" in data:
             if "device" in data["policy_kwargs"]:
                 del data["policy_kwargs"]["device"]
+            # backward compatibility, convert to new format
+            if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
+                saved_net_arch = data["policy_kwargs"]["net_arch"]
+                if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
+                    data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
 
         if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
             raise ValueError(
@@ -726,7 +731,6 @@ def load(
                 )
             else:
                 raise e
-
         # put other pytorch variables back in place
         if pytorch_variables is not None:
             for name in pytorch_variables:
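
For reference, the backward-compatibility branch added to ``load()`` above unwraps the ``net_arch`` format saved by SB3 < 1.8.0 (``[dict(pi=..., vf=...)]``) into the new plain dict. A self-contained sketch of the same conversion; the helper name ``convert_net_arch`` is hypothetical and not part of the codebase:

    from typing import Any, Dict

    def convert_net_arch(policy_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Mirror of the check in BaseAlgorithm.load(): unwrap the old
        # one-element list-of-dict format into a plain dict.
        net_arch = policy_kwargs.get("net_arch")
        if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
            policy_kwargs["net_arch"] = net_arch[0]
        return policy_kwargs

    # Old saved format is unwrapped...
    assert convert_net_arch({"net_arch": [dict(pi=[64], vf=[64])]}) == {"net_arch": dict(pi=[64], vf=[64])}
    # ...while the new format passes through unchanged.
    assert convert_net_arch({"net_arch": dict(pi=[64], vf=[64])}) == {"net_arch": dict(pi=[64], vf=[64])}
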

stable_baselines3/common/buffers.py

Lines changed: 24 additions & 5 deletions
@@ -474,7 +474,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
             yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size
 
-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
+    def _get_samples(
+        self,
+        batch_inds: np.ndarray,
+        env: Optional[VecNormalize] = None,
+    ) -> RolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
         data = (
             self.observations[batch_inds],
             self.actions[batch_inds],
@@ -603,7 +607,11 @@ def add(
             self.full = True
             self.pos = 0
 
-    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:  # type: ignore[signature-mismatch] #FIXME:
+    def sample(
+        self,
+        batch_size: int,
+        env: Optional[VecNormalize] = None,
+    ) -> DictReplayBufferSamples:  # type: ignore[signature-mismatch] #FIXME:
         """
         Sample elements from the replay buffer.
 
@@ -614,7 +622,11 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictRep
         """
         return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
 
-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:  # type: ignore[signature-mismatch] #FIXME:
+    def _get_samples(
+        self,
+        batch_inds: np.ndarray,
+        env: Optional[VecNormalize] = None,
+    ) -> DictReplayBufferSamples:  # type: ignore[signature-mismatch] #FIXME:
         # Sample randomly the env idx
         env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
 
@@ -743,7 +755,10 @@ def add(
         if self.pos == self.buffer_size:
             self.full = True
 
-    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:  # type: ignore[signature-mismatch] #FIXME
+    def get(
+        self,
+        batch_size: Optional[int] = None,
+    ) -> Generator[DictRolloutBufferSamples, None, None]:  # type: ignore[signature-mismatch] #FIXME
         assert self.full, ""
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
@@ -767,7 +782,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSa
             yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size
 
-    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
+    def _get_samples(
+        self,
+        batch_inds: np.ndarray,
+        env: Optional[VecNormalize] = None,
+    ) -> DictRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
 
         return DictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},

stable_baselines3/common/policies.py

Lines changed: 12 additions & 25 deletions
@@ -418,8 +418,7 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        # TODO(antonin): update type annotation when we remove shared network support
-        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
+        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
@@ -452,21 +451,15 @@ def __init__(
             normalize_images=normalize_images,
         )
 
-        # Convert [dict()] to dict() as shared network are deprecated
-        if isinstance(net_arch, list) and len(net_arch) > 0:
-            if isinstance(net_arch[0], dict):
-                warnings.warn(
-                    (
-                        "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
-                        "you should now pass directly a dictionary and not a list "
-                        "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
-                    ),
-                )
-                net_arch = net_arch[0]
-            else:
-                # Note: deprecation warning will be emitted
-                # by the MlpExtractor constructor
-                pass
+        if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
+            warnings.warn(
+                (
+                    "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
+                    "you should now pass directly a dictionary and not a list "
+                    "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
+                ),
+            )
+            net_arch = net_arch[0]
 
         # Default network architecture, from stable-baselines
         if net_arch is None:
@@ -488,12 +481,6 @@ def __init__(
         else:
             self.pi_features_extractor = self.features_extractor
             self.vf_features_extractor = self.make_features_extractor()
-            # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
-            # TODO(antonin): update the check once we change net_arch behavior
-            if isinstance(net_arch, list) and len(net_arch) > 0:
-                raise ValueError(
-                    "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
-                )
 
         self.log_std_init = log_std_init
         dist_kwargs = None
@@ -770,7 +757,7 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
+        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
@@ -843,7 +830,7 @@ def __init__(
         observation_space: spaces.Dict,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
+        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         activation_fn: Type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
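
For reference, with this compatibility shim the old list-of-dict ``net_arch`` should still be accepted by the actor-critic policies (converted with a warning), while the plain dict is the supported form. A rough illustration, assuming the standard ``A2C`` constructor and CartPole-v1 env; not part of this diff and untested against this exact commit:

    import warnings

    from stable_baselines3 import A2C

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old format: should be converted to dict(pi=[64], vf=[64]) with a warning
        A2C("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]))
    print(any("mlp_extractor" in str(w.message) for w in caught))  # expected: True

    # Preferred, equivalent call with the new format
    A2C("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=dict(pi=[64], vf=[64])))
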
