@@ -51,8 +51,6 @@ Each of these network have a features extractor followed by a fully-connected ne
.. image:: ../_static/img/sb3_policy.png


- .. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
-

Custom Network Architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -90,13 +88,13 @@ using ``policy_kwargs`` parameter:
# of two layers of size 32 each with Relu activation function
# Note: an extra linear layer will be added on top of the pi and the vf nets, respectively
policy_kwargs = dict(activation_fn=th.nn.ReLU,
-                      net_arch=[dict(pi=[32, 32], vf=[32, 32])])
+                      net_arch=dict(pi=[32, 32], vf=[32, 32]))
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
- model.learn(total_timesteps=100000)
+ model.learn(total_timesteps=20_000)
# Save the agent
model.save("ppo_cartpole")
@@ -114,13 +112,14 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t

.. note::

- By default the features extractor is shared between the actor and the critic to save computation (when applicable).
+ For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable).
However, this can be changed by setting ``share_features_extractor=False`` in the
``policy_kwargs`` (both for on-policy and off-policy algorithms).
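For instance, a minimal sketch of that option (the algorithm and environment below are only placeholders, any other combination works the same way):

.. code-block:: python

    from stable_baselines3 import PPO

    # Request separate features extractors for the actor and the critic
    policy_kwargs = dict(share_features_extractor=False)
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)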


.. warning::
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
+ Please note that this option is **deprecated**; in a future release, the layers in the ``mlp_extractor`` will have to be non-shared.


.. code-block:: python
@@ -240,64 +239,56 @@ downsampling and "vector" with a single linear layer.
On-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^

- Shared Networks
+ Custom Networks
---------------

- The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many
- of them are shared between the policy network and the value network. It is assumed to be a list with the following
- structure:
-
- 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
-    If the number of ints is zero, there will be no shared layers.
- 2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
-    It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
-    If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed.
-
- In short: ``[<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])]``.
-
- Examples
- ~~~~~~~~
-
- Two shared layers of size 128: ``net_arch=[128, 128]``
-
-
- .. code-block:: none
+ .. warning::
+ Shared layers in the ``mlp_extractor`` are **deprecated**.
+ In a future release all layers will have to be non-shared.
+ If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).

-          obs
-           |
-         <128>
-           |
-         <128>
-         /   \
-    action    value
+ .. warning::
+ In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
+ to match that of off-policy algorithms: it will create **separate** networks (instead of the currently shared ones)
+ for the actor and the critic, with the same architecture.

- Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]``
+ If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
+ you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.

- .. code-block:: none
+ For example, if you want a different architecture for the actor (aka ``pi``) and the critic (value-function aka ``vf``) networks,
+ then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.

-          obs
-           |
-         <128>
-         /   \
-    action   <256>
-                |
-              <256>
-                |
-              value
+ .. Otherwise, to have actor and critic that share the same network architecture,
+ .. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).

+ Examples
+ ~~~~~~~~

- Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]``
+ .. TODO(antonin): uncomment when shared network is removed
+ .. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
+ ..
+ .. .. code-block:: none
+ ..
+ ..          obs
+ ..         /   \
+ ..      <128> <128>
+ ..        |     |
+ ..      <128> <128>
+ ..        |     |
+ ..     action  value
+
+ Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``

.. code-block:: none

-          obs
-           |
-         <128>
-         /   \
-      <16>   <256>
-        |      |
-     action  value
+          obs
+         /   \
+      <32>   <64>
+        |      |
+      <32>   <64>
+        |      |
+     action  value
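As a usage sketch matching the example above (the algorithm and environment here are only placeholders), this architecture can be passed through ``policy_kwargs``:

.. code-block:: python

    from stable_baselines3 import A2C

    # Actor (pi): two hidden layers of 32 units
    # Critic (vf): two hidden layers of 64 units
    policy_kwargs = dict(net_arch=dict(pi=[32, 32], vf=[64, 64]))
    model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=5_000)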
Advanced Example
@@ -334,7 +325,7 @@ If your task requires even more granular control over the policy/value architect
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
-         super(CustomNetwork, self).__init__()
+         super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
@@ -370,8 +361,6 @@ If your task requires even more granular control over the policy/value architect
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
-         net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
-         activation_fn: Type[nn.Module] = nn.Tanh,
        *args,
        **kwargs,
    ):
@@ -380,8 +369,6 @@ If your task requires even more granular control over the policy/value architect
            observation_space,
            action_space,
            lr_schedule,
-             net_arch,
-             activation_fn,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
@@ -402,21 +389,16 @@ If your task requires even more granular control over the policy/value architect
Off-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^^

- If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
- you can pass a dictionary of the following structure: ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
+ If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG``, ``TQC`` or ``TD3``,
+ you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], qf=[<critic network architecture>])``.

For example, if you want a different architecture for the actor (aka ``pi``) and the critic (Q-function aka ``qf``) networks,
- then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
+ then you can specify ``net_arch=dict(pi=[64, 64], qf=[400, 300])``.

Otherwise, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
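As a usage sketch of the structure described above (``SAC`` and ``Pendulum-v1`` are only placeholders, any off-policy algorithm and compatible environment work the same way):

.. code-block:: python

    from stable_baselines3 import SAC

    # Actor (pi): two hidden layers of 64 units
    # Critics (qf): hidden layers of 400 and 300 units
    policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)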

- .. note::
- Compared to their on-policy counterparts, no shared layers (other than the features extractor)
- between the actor and the critic are allowed (to prevent issues with target networks).
-
-

.. note::
For advanced customization of off-policy algorithm policies, please take a look at the code.
A good understanding of the algorithm used is required; see the discussion in `issue #425 <https://github.com/DLR-RM/stable-baselines3/issues/425>`_