
Commit 82bc63f

qgallouedec and araffin authored
Upgrade black formatting (#1310)
* apply black

* Reformat tests

---------

Co-authored-by: Antonin Raffin <[email protected]>
1 parent bea3c44 commit 82bc63f

22 files changed: 4 additions, 44 deletions
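The changes below are mechanical: the newer black style drops blank lines that sit directly after a block-opening line (a signature's closing `):`, an `if`, a `for`), and strips redundant parentheses, e.g. around for-loop targets and around the return annotation in save_util.py. As a rough, non-authoritative sketch, the snippet below runs black programmatically on a made-up example; `SOURCE` is invented for illustration (not code from this repository), and the exact output depends on the black version and mode installed.

# Illustrative only: apply black's formatter to a small snippet to see the
# same kind of changes as in this diff (blank line after the block opener
# removed, parentheses around the for-loop target removed), depending on
# the black release in use.
import black

SOURCE = '''\
def write_keys(keys):

    for (i, key) in enumerate(keys):
        print(i, key)
'''

# black.format_str applies the configured style; Mode() uses the defaults.
formatted = black.format_str(SOURCE, mode=black.Mode())
print(formatted)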

stable_baselines3/a2c/a2c.py

Lines changed: 0 additions & 3 deletions
@@ -81,7 +81,6 @@ def __init__(
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
@@ -132,7 +131,6 @@ def train(self) -> None:

         # This will only loop once (get all data in one go)
         for rollout_data in self.rollout_buffer.get(batch_size=None):
-
             actions = rollout_data.actions
             if isinstance(self.action_space, spaces.Discrete):
                 # Convert discrete action from float to long
@@ -189,7 +187,6 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfA2C:
-
         return super().learn(
             total_timesteps=total_timesteps,
             callback=callback,

stable_baselines3/common/buffers.py

Lines changed: 0 additions & 7 deletions
@@ -240,7 +240,6 @@ def add(
         done: np.ndarray,
         infos: List[Dict[str, Any]],
     ) -> None:
-
         # Reshape needed when using multiple envs with discrete observations
         # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
         if isinstance(self.observation_space, spaces.Discrete):
@@ -346,7 +345,6 @@ def __init__(
         gamma: float = 0.99,
         n_envs: int = 1,
     ):
-
         super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
         self.gae_lambda = gae_lambda
         self.gamma = gamma
@@ -356,7 +354,6 @@ def __init__(
         self.reset()

     def reset(self) -> None:
-
         self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
         self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -451,7 +448,6 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
         if not self.generator_ready:
-
             _tensor_names = [
                 "observations",
                 "actions",
@@ -688,7 +684,6 @@ def __init__(
         gamma: float = 0.99,
         n_envs: int = 1,
     ):
-
         super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

         assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
@@ -763,7 +758,6 @@ def get(
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
         if not self.generator_ready:
-
             for key, obs in self.observations.items():
                 self.observations[key] = self.swap_and_flatten(obs)

@@ -787,7 +781,6 @@ def _get_samples(
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
     ) -> DictRolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
-
         return DictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),

stable_baselines3/common/callbacks.py

Lines changed: 0 additions & 2 deletions
@@ -429,11 +429,9 @@ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any
             self._is_success_buffer.append(maybe_is_success)

     def _on_step(self) -> bool:
-
         continue_training = True

         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
-
             # Sync training and eval env if there is VecNormalize
             if self.model.get_vec_normalize_env() is not None:
                 try:

stable_baselines3/common/evaluation.py

Lines changed: 0 additions & 1 deletion
@@ -91,7 +91,6 @@ def evaluate_policy(
         current_lengths += 1
         for i in range(n_envs):
             if episode_counts[i] < episode_count_targets[i]:
-
                 # unpack values so that the callback can access the local variables
                 reward = rewards[i]
                 done = dones[i]

stable_baselines3/common/logger.py

Lines changed: 1 addition & 4 deletions
@@ -173,7 +173,6 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
         key2str = {}
         tag = None
         for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
-
             if excluded is not None and ("stdout" in excluded or "log" in excluded):
                 continue

@@ -342,7 +341,7 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
         self.file.seek(0)
         lines = self.file.readlines()
         self.file.seek(0)
-        for (i, key) in enumerate(self.keys):
+        for i, key in enumerate(self.keys):
             if i > 0:
                 self.file.write(",")
             self.file.write(key)
@@ -399,9 +398,7 @@ def __init__(self, folder: str):
         self.writer = SummaryWriter(log_dir=folder)

     def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
-
         for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
-
             if excluded is not None and "tensorboard" in excluded:
                 continue


stable_baselines3/common/off_policy_algorithm.py

Lines changed: 0 additions & 2 deletions
@@ -102,7 +102,6 @@ def __init__(
         sde_support: bool = True,
         supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
     ):
-
         super().__init__(
             policy=policy,
             env=env,
@@ -319,7 +318,6 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfOffPolicyAlgorithm:
-
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
             callback,

stable_baselines3/common/on_policy_algorithm.py

Lines changed: 0 additions & 2 deletions
@@ -72,7 +72,6 @@ def __init__(
         _init_setup_model: bool = True,
         supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
     ):
-
         super().__init__(
             policy=policy,
             env=env,
@@ -244,7 +243,6 @@ def learn(
         callback.on_training_start(locals(), globals())

         while self.num_timesteps < total_timesteps:
-
             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

             if continue_training is False:

stable_baselines3/common/policies.py

Lines changed: 0 additions & 1 deletion
@@ -433,7 +433,6 @@ def __init__(
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
     ):
-
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
         # Small values to avoid NaN in Adam optimizer

stable_baselines3/common/results_plotter.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def plot_curves(
     plt.figure(title, figsize=figsize)
     max_x = max(xy[0][-1] for xy in xy_list)
     min_x = 0
-    for (_, (x, y)) in enumerate(xy_list):
+    for _, (x, y) in enumerate(xy_list):
         plt.scatter(x, y, s=2)
         # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
         if x.shape[0] >= EPISODES_WINDOW:

stable_baselines3/common/save_util.py

Lines changed: 1 addition & 1 deletion
@@ -367,7 +367,7 @@ def load_from_zip_file(
     device: Union[th.device, str] = "auto",
     verbose: int = 0,
     print_system_info: bool = False,
-) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
+) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
     """
     Load model data from a .zip archive
