Skip to content

Commit 69fdf15

Browse files
authored
Downgrade sphinx-autodoc-typehints (#1291)
* Update setup.py * black * hotfix pytype
1 parent 92f7a6f commit 69fdf15

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
# For spelling
118118
"sphinxcontrib.spelling",
119119
# Type hints support
120-
"sphinx-autodoc-typehints",
120+
"sphinx-autodoc-typehints==1.21.1", # TODO: remove version constraint, see #1290
121121
# Copy button for code snippets
122122
"sphinx_copybutton",
123123
],

stable_baselines3/common/buffers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
474474
yield self._get_samples(indices[start_idx : start_idx + batch_size])
475475
start_idx += batch_size
476476

477-
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
477+
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
478478
data = (
479479
self.observations[batch_inds],
480480
self.actions[batch_inds],
@@ -603,7 +603,7 @@ def add(
603603
self.full = True
604604
self.pos = 0
605605

606-
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
606+
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
607607
"""
608608
Sample elements from the replay buffer.
609609
@@ -614,7 +614,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictRep
614614
"""
615615
return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
616616

617-
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
617+
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
618618
# Sample randomly the env idx
619619
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
620620

@@ -743,7 +743,7 @@ def add(
743743
if self.pos == self.buffer_size:
744744
self.full = True
745745

746-
def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
746+
def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
747747
assert self.full, ""
748748
indices = np.random.permutation(self.buffer_size * self.n_envs)
749749
# Prepare the data
@@ -767,7 +767,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSa
767767
yield self._get_samples(indices[start_idx : start_idx + batch_size])
768768
start_idx += batch_size
769769

770-
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
770+
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
771771

772772
return DictRolloutBufferSamples(
773773
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},

stable_baselines3/common/envs/identity_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l
7171
super().__init__(ep_length=ep_length, space=space)
7272
self.eps = eps
7373

74-
def step(self, action: np.ndarray) -> GymStepReturn:
74+
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
7575
reward = self._get_reward(action)
7676
self._choose_next_state()
7777
self.current_step += 1
@@ -83,7 +83,7 @@ def _get_reward(self, action: np.ndarray) -> float:
8383

8484

8585
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
86-
def __init__(self, dim: int = 1, ep_length: int = 100):
86+
def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
8787
"""
8888
Identity environment for testing purposes
8989
@@ -95,7 +95,7 @@ def __init__(self, dim: int = 1, ep_length: int = 100):
9595

9696

9797
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
98-
def __init__(self, dim: int = 1, ep_length: int = 100):
98+
def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
9999
"""
100100
Identity environment for testing purposes
101101
@@ -126,7 +126,7 @@ def __init__(
126126
n_channels: int = 1,
127127
discrete: bool = True,
128128
channel_first: bool = False,
129-
):
129+
) -> None:
130130
self.observation_shape = (screen_height, screen_width, n_channels)
131131
if channel_first:
132132
self.observation_shape = (n_channels, screen_height, screen_width)

0 commit comments

Comments
 (0)