Skip to content

Commit 7ce7b6a

Browse files
authored
Update defaults for offpolicy algos with features extractor (#935)
1 parent d68f0a2 commit 7ce7b6a

File tree

4 files changed

+12
-13
lines changed

4 files changed

+12
-13
lines changed

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ Changelog
44
==========
55

66

7-
Release 1.5.1a8 (WIP)
7+
Release 1.5.1a9 (WIP)
88
---------------------------
99

1010
Breaking Changes:
1111
^^^^^^^^^^^^^^^^^
1212
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
1313
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
1414
- SB3 now requires PyTorch >= 1.11
15+
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3,
16+
``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before)
1517

1618
New Features:
1719
^^^^^^^^^^^^^

stable_baselines3/sac/policies.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def __init__(
235235
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
236236
optimizer_kwargs: Optional[Dict[str, Any]] = None,
237237
n_critics: int = 2,
238-
share_features_extractor: bool = True,
238+
share_features_extractor: bool = False,
239239
):
240240
super().__init__(
241241
observation_space,
@@ -248,10 +248,7 @@ def __init__(
248248
)
249249

250250
if net_arch is None:
251-
if features_extractor_class == NatureCNN:
252-
net_arch = []
253-
else:
254-
net_arch = [256, 256]
251+
net_arch = [256, 256]
255252

256253
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
257254

@@ -422,7 +419,7 @@ def __init__(
422419
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
423420
optimizer_kwargs: Optional[Dict[str, Any]] = None,
424421
n_critics: int = 2,
425-
share_features_extractor: bool = True,
422+
share_features_extractor: bool = False,
426423
):
427424
super().__init__(
428425
observation_space,
@@ -493,7 +490,7 @@ def __init__(
493490
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
494491
optimizer_kwargs: Optional[Dict[str, Any]] = None,
495492
n_critics: int = 2,
496-
share_features_extractor: bool = True,
493+
share_features_extractor: bool = False,
497494
):
498495
super().__init__(
499496
observation_space,

stable_baselines3/td3/policies.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __init__(
119119
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
120120
optimizer_kwargs: Optional[Dict[str, Any]] = None,
121121
n_critics: int = 2,
122-
share_features_extractor: bool = True,
122+
share_features_extractor: bool = False,
123123
):
124124
super().__init__(
125125
observation_space,
@@ -134,7 +134,7 @@ def __init__(
134134
# Default network architecture, from the original paper
135135
if net_arch is None:
136136
if features_extractor_class == NatureCNN:
137-
net_arch = []
137+
net_arch = [256, 256]
138138
else:
139139
net_arch = [400, 300]
140140

@@ -281,7 +281,7 @@ def __init__(
281281
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
282282
optimizer_kwargs: Optional[Dict[str, Any]] = None,
283283
n_critics: int = 2,
284-
share_features_extractor: bool = True,
284+
share_features_extractor: bool = False,
285285
):
286286
super().__init__(
287287
observation_space,
@@ -335,7 +335,7 @@ def __init__(
335335
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
336336
optimizer_kwargs: Optional[Dict[str, Any]] = None,
337337
n_critics: int = 2,
338-
share_features_extractor: bool = True,
338+
share_features_extractor: bool = False,
339339
):
340340
super().__init__(
341341
observation_space,

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.5.1a8
1+
1.5.1a9

0 commit comments

Comments
 (0)