
Commit e1ca24a

kplers and araffin authored
Add policy documentation links to policy_kwargs parameter (Stable-Baselines-Team#266)
* Add policy documentation links to policy_kwargs parameter
* Sort `__all__`

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent: 36c21ac

File tree

15 files changed: +17 -15 lines changed

docs/modules/ppo_mask.rst (1 addition, 0 deletions)

@@ -245,6 +245,7 @@ Parameters
   :members:
   :inherited-members:
 
+.. _ppo_mask_policies:
 
 MaskablePPO Policies
 --------------------
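For context, the new `ppo_mask_policies` label is the target that the `policy_kwargs` docstring link added below resolves to, and the policy classes it documents are where `policy_kwargs` entries end up. A minimal sketch of that flow (the env, mask function, and `net_arch` value are illustrative, not part of this commit):

import gymnasium as gym
import numpy as np

from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker


def mask_fn(env: gym.Env) -> np.ndarray:
    # Illustrative mask: every action allowed; a real env would flag
    # invalid actions as False.
    return np.ones(env.action_space.n, dtype=bool)


env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
# policy_kwargs is forwarded to the policy constructor documented in the
# "MaskablePPO Policies" section that the new label anchors.
model = MaskablePPO("MlpPolicy", env, policy_kwargs=dict(net_arch=[64, 64]))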

docs/modules/ppo_recurrent.rst (1 addition, 0 deletions)

@@ -125,6 +125,7 @@ Parameters
   :members:
   :inherited-members:
 
+.. _ppo_recurrent_policies:
 
 RecurrentPPO Policies
 ---------------------

sb3_contrib/__init__.py (3 additions, 3 deletions)

@@ -15,10 +15,10 @@
 
 __all__ = [
     "ARS",
-    "CrossQ",
-    "MaskablePPO",
-    "RecurrentPPO",
     "QRDQN",
     "TQC",
     "TRPO",
+    "CrossQ",
+    "MaskablePPO",
+    "RecurrentPPO",
 ]
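Note that the new ordering is not a plain alphabetical sort: all-uppercase names come before CamelCase ones, consistent with the isort-style sort that tools such as ruff's RUF022 rule apply to `__all__` (my reading of the convention; the commit does not say which tool was used). A quick Python check of what a naive sort would produce instead:

names = ["ARS", "QRDQN", "TQC", "TRPO", "CrossQ", "MaskablePPO", "RecurrentPPO"]
# A plain case-sensitive sort interleaves acronyms and CamelCase names:
print(sorted(names))
# ['ARS', 'CrossQ', 'MaskablePPO', 'QRDQN', 'RecurrentPPO', 'TQC', 'TRPO']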

sb3_contrib/ars/ars.py (1 addition, 1 deletion)

@@ -40,7 +40,7 @@ class ARS(BaseAlgorithm):
     :param zero_policy: Boolean determining if the passed policy should have it's weights zeroed before training.
     :param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
     :param n_eval_episodes: Number of episodes to evaluate each candidate.
-    :param policy_kwargs: Keyword arguments to pass to the policy on creation
+    :param policy_kwargs: Keyword arguments to pass to the policy on creation. See :ref:`ars_policies`
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: String with the directory to put tensorboard logs:

sb3_contrib/common/torch_layers.py (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 import torch
 
-__all__ = ["BatchRenorm1d", "BatchRenorm"]
+__all__ = ["BatchRenorm", "BatchRenorm1d"]
 
 
 class BatchRenorm(torch.nn.Module):

sb3_contrib/crossq/crossq.py (1 addition, 1 deletion)

@@ -56,7 +56,7 @@ class CrossQ(OffPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`crossq_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
     :param seed: Seed for the pseudo random generators

sb3_contrib/ppo_mask/__init__.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
 from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
 
-__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"]
+__all__ = ["CnnPolicy", "MaskablePPO", "MlpPolicy", "MultiInputPolicy"]

sb3_contrib/ppo_mask/ppo_mask.py (1 addition, 1 deletion)

@@ -57,7 +57,7 @@ class MaskablePPO(OnPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_mask_policies`
     :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.

sb3_contrib/ppo_recurrent/ppo_recurrent.py (1 addition, 1 deletion)

@@ -57,7 +57,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
-    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_recurrent_policies`
     :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.

sb3_contrib/qrdqn/__init__.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
 from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
 from sb3_contrib.qrdqn.qrdqn import QRDQN
 
-__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"]
+__all__ = ["QRDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
