
Commit bec0038

Upgrade to python 3.7+ syntax (#69)
* Upgrade to python 3.7+ syntax
* Switch to PyTorch 1.11
1 parent 812648e commit bec0038
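Most of the diff below is the mechanical cleanup that a tool such as pyupgrade (run with its Python 3.7+ target) performs once Python 2 compatibility is dropped. A rough before/after sketch, using a hypothetical class that is not taken from this repository:

# Before: Python 2 compatible spellings (hypothetical example, not from this repo)
class RolloutWorker(object):
    def __init__(self, n_envs=1):
        super(RolloutWorker, self).__init__()
        self.n_envs = n_envs

# After pyupgrade: same behaviour, Python 3.7+ syntax
class RolloutWorker:
    def __init__(self, n_envs=1):
        super().__init__()  # bare super() resolves to the same class in the MRO
        self.n_envs = n_envs

# File reads also drop the redundant mode argument, since "r" is the default:
#     with open(version_file, "r") as f:   ->   with open(version_file) as f: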

File tree: 17 files changed (+39 / -40 lines)


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           # cpu version of pytorch
-          pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+          pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
           # Install dependencies for docs and tests
           pip install stable_baselines3[extra,tests,docs]
           # Install master version

docs/conf.py

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # Configuration file for the Sphinx documentation builder.
 #
@@ -46,7 +45,7 @@ def __getattr__(cls, name):

 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
     __version__ = file_handler.read().strip()

 # -- Project information -----------------------------------------------------

docs/misc/changelog.rst

Lines changed: 4 additions & 2 deletions
@@ -3,15 +3,17 @@
 Changelog
 ==========

-Release 1.5.1a1 (WIP)
+Release 1.5.1a5 (WIP)
 -------------------------------

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
-- Upgraded to Stable-Baselines3 >= 1.5.1a1
+- Upgraded to Stable-Baselines3 >= 1.5.1a5
 - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
   ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
 - Renamed ``rollout/exploration rate`` key to ``rollout/exploration_rate`` for QRDQN (to be consistent with SB3 DQN)
+- Upgraded to python 3.7+ syntax using ``pyupgrade``
+- SB3 now requires PyTorch >= 1.11

 New Features:
 ^^^^^^^^^^^^^

sb3_contrib/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@

 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
     __version__ = file_handler.read().strip()

sb3_contrib/common/maskable/buffers.py

Lines changed: 3 additions & 5 deletions
@@ -145,9 +145,7 @@ def __init__(
         n_envs: int = 1,
     ):
         self.action_masks = None
-        super(MaskableDictRolloutBuffer, self).__init__(
-            buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs
-        )
+        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)

     def reset(self) -> None:
         if isinstance(self.action_space, spaces.Discrete):
@@ -162,7 +160,7 @@ def reset(self) -> None:
         self.mask_dims = mask_dims
         self.action_masks = np.ones((self.buffer_size, self.n_envs, self.mask_dims), dtype=np.float32)

-        super(MaskableDictRolloutBuffer, self).reset()
+        super().reset()

     def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
         """
@@ -171,7 +169,7 @@ def add(self, *args, action_masks: Optional[np.ndarray] = None, **kwargs) -> None:
         if action_masks is not None:
             self.action_masks[self.pos] = action_masks.reshape((self.n_envs, self.mask_dims))

-        super(MaskableDictRolloutBuffer, self).add(*args, **kwargs)
+        super().add(*args, **kwargs)

     def get(self, batch_size: Optional[int] = None) -> Generator[MaskableDictRolloutBufferSamples, None, None]:
         assert self.full, ""

sb3_contrib/common/maskable/policies.py

Lines changed: 2 additions & 2 deletions
@@ -345,7 +345,7 @@ def __init__(
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MaskableActorCriticCnnPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,
@@ -396,7 +396,7 @@ def __init__(
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MaskableMultiInputActorCriticPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,

sb3_contrib/common/vec_env/async_eval.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def _worker(
             break


-class AsyncEval(object):
+class AsyncEval:
     """
     Helper class to do asynchronous evaluation of different policies with multiple processes.
     It is useful when implementing population based methods like Evolution Strategies (ES),
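For context on the class being touched here: AsyncEval spreads policy evaluation over worker processes, which ARS uses to evaluate candidate policies in parallel. A minimal usage sketch, assuming the ARS + AsyncEval interface documented in sb3-contrib (exact constructor arguments may differ between versions):

from sb3_contrib import ARS
from sb3_contrib.common.vec_env import AsyncEval
from stable_baselines3.common.env_util import make_vec_env

if __name__ == "__main__":
    env_id = "CartPole-v1"
    n_envs = 2  # number of evaluation worker processes

    model = ARS("MlpPolicy", env_id, verbose=1)
    # One env factory per worker; each process evaluates candidate policies asynchronously.
    async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)
    model.learn(total_timesteps=20_000, async_eval=async_eval)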

sb3_contrib/common/wrappers/time_feature.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False)
         else:
             env.observation_space = gym.spaces.Box(low=low, high=high, dtype=self.dtype)

-        super(TimeFeatureWrapper, self).__init__(env)
+        super().__init__(env)

         # Try to infer the max number of steps per episode
         try:
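TimeFeatureWrapper, the class modernised above, appends a remaining-time feature to the observation so that time-limited episodes stay closer to Markovian. A minimal usage sketch, with the environment id picked only for illustration:

import gym

from sb3_contrib.common.wrappers import TimeFeatureWrapper

# The observation gains one extra entry in [0, 1] encoding the remaining episode time.
env = TimeFeatureWrapper(gym.make("Pendulum-v1"), max_steps=1000)
obs = env.reset()
print(obs.shape)  # original observation dimension + 1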

sb3_contrib/qrdqn/policies.py

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ def __init__(
         activation_fn: Type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ):
-        super(QuantileNetwork, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor=features_extractor,
@@ -125,7 +125,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):

-        super(QRDQNPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             features_extractor_class,
@@ -246,7 +246,7 @@ def __init__(
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(CnnPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,
@@ -294,7 +294,7 @@ def __init__(
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        super(MultiInputPolicy, self).__init__(
+        super().__init__(
             observation_space,
             action_space,
             lr_schedule,

sb3_contrib/qrdqn/qrdqn.py

Lines changed: 4 additions & 4 deletions
@@ -93,7 +93,7 @@ def __init__(
         _init_setup_model: bool = True,
     ):

-        super(QRDQN, self).__init__(
+        super().__init__(
             policy,
             env,
             learning_rate,
@@ -139,7 +139,7 @@ def __init__(
             self._setup_model()

     def _setup_model(self) -> None:
-        super(QRDQN, self)._setup_model()
+        super()._setup_model()
         self._create_aliases()
         self.exploration_schedule = get_linear_fn(
             self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
@@ -253,7 +253,7 @@ def learn(
         reset_num_timesteps: bool = True,
     ) -> OffPolicyAlgorithm:

-        return super(QRDQN, self).learn(
+        return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,
             log_interval=log_interval,
@@ -266,7 +266,7 @@ def learn(
         )

     def _excluded_save_params(self) -> List[str]:
-        return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"]
+        return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]

     def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
         state_dicts = ["policy", "policy.optimizer"]
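The QRDQN changes above are purely syntactic (bare super() calls); training behaviour is unchanged. For reference, a minimal QRDQN run, with hyperparameters chosen only for illustration:

from sb3_contrib import QRDQN

# exploration_fraction feeds the linear epsilon schedule built in _setup_model above
model = QRDQN("MlpPolicy", "CartPole-v1", exploration_fraction=0.1, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("qrdqn_cartpole")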
