
Commit 9caa168

kplers and araffin authored

Add policy documentation links to policy_kwargs parameter (#2050)

* docs: Add policy documentation links to policy_kwargs parameter
* Fix missing references, update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>

1 parent 897d01d

File tree

9 files changed: +16 −11 lines changed


docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion

@@ -38,6 +38,7 @@ Documentation:
 ^^^^^^^^^^^^^^
 - Added Decisions and Dragons to resources. (@jmacglashan)
 - Updated PyBullet example, now compatible with Gymnasium
+- Added link to policies for ``policy_kwargs`` parameter (@kplers)

 Release 2.4.0 (2024-11-18)
 --------------------------
@@ -1738,4 +1739,4 @@ And all the contributors:
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
 @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
-@brn-dev @jmacglashan
+@brn-dev @jmacglashan @kplers

docs/modules/a2c.rst

Lines changed: 4 additions & 2 deletions

@@ -78,7 +78,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.

 A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python

     from stable_baselines3 import A2C
     from stable_baselines3.common.env_util import make_vec_env
@@ -88,7 +88,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
     env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
     model = A2C("MlpPolicy", env, device="cpu")
     model.learn(total_timesteps=25_000)
-
+
 For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.


@@ -165,6 +165,8 @@ Parameters
   :inherited-members:


+.. _a2c_policies:
+
 A2C Policies
 -------------
docs/modules/ppo.rst

Lines changed: 4 additions & 2 deletions

@@ -92,7 +92,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.

 PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python

     from stable_baselines3 import PPO
     from stable_baselines3.common.env_util import make_vec_env
@@ -102,7 +102,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
     env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
     model = PPO("MlpPolicy", env, device="cpu")
     model.learn(total_timesteps=25_000)
-
+
 For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

 Results
@@ -178,6 +178,8 @@ Parameters
   :inherited-members:


+.. _ppo_policies:
+
 PPO Policies
 -------------
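The snippet highlighted by both doc pages relies on the vectorized-environment abstraction: one object that steps several environment copies and returns batched results. As a rough illustration of that idea only (a minimal single-process stand-in with a toy counter environment invented here, not the actual ``DummyVecEnv``/``SubprocVecEnv`` implementation):

```python
class ToyEnv:
    """Toy environment invented for illustration: counts steps, 'done' after 3."""
    def reset(self):
        self.t = 0
        return self.t  # observation

    def step(self, action):
        self.t += 1
        obs, reward, done = self.t, 1.0, self.t >= 3
        if done:
            obs = self.reset()  # auto-reset on episode end, as SB3 vec envs do
        return obs, reward, done


class ToyVecEnv:
    """Sequential analogue of DummyVecEnv: one process, a plain Python loop.
    SubprocVecEnv instead runs each copy in its own worker process, which is
    what improves CPU utilization for cheap-to-step environments."""
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs, rewards, dones = map(list, zip(*results))
        return obs, rewards, dones


vec_env = ToyVecEnv([ToyEnv for _ in range(4)])
obs = vec_env.reset()                            # [0, 0, 0, 0]
obs, rewards, dones = vec_env.step([0, 0, 0, 0])  # batched step results
```

The real classes return NumPy arrays and handle seeding, rendering, and info dicts; this sketch only shows the batching contract the docs snippet builds on.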

stable_baselines3/a2c/a2c.py

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ class A2C(OnPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

stable_baselines3/ddpg/ddpg.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ class DDPG(TD3):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

stable_baselines3/dqn/dqn.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ class DQN(OffPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

stable_baselines3/ppo/ppo.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ class PPO(OnPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

stable_baselines3/sac/sac.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ class SAC(OffPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`sac_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

stable_baselines3/td3/td3.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ class TD3(OffPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators
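All six docstrings describe the same mechanism: the algorithm stores ``policy_kwargs`` and forwards it to the policy constructor, so the valid keys are exactly the policy's own parameters (which is why each docstring now links to the matching policies section). A schematic sketch of that forwarding pattern, with class and parameter names invented for illustration (not the actual SB3 code):

```python
class ToyPolicy:
    """Stand-in policy with tunable options, loosely like SB3's MlpPolicy.
    net_arch mirrors a real SB3 policy kwarg; activation is invented here."""
    def __init__(self, net_arch=(64, 64), activation="tanh"):
        self.net_arch = list(net_arch)
        self.activation = activation


class ToyAlgorithm:
    """Stand-in algorithm: stores policy_kwargs, then unpacks them
    into the policy constructor when the policy is created."""
    def __init__(self, policy_class, policy_kwargs=None):
        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        self.policy = policy_class(**self.policy_kwargs)


model = ToyAlgorithm(ToyPolicy, policy_kwargs=dict(net_arch=[256, 256]))
print(model.policy.net_arch)    # [256, 256]
print(model.policy.activation)  # tanh (default kept)
```

With real SB3 this corresponds to something like ``PPO("MlpPolicy", env, policy_kwargs=dict(net_arch=[256, 256]))``; an unknown key raises a ``TypeError`` from the policy constructor, just as it would in this sketch.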
