@@ -474,7 +474,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
474
474
yield self ._get_samples (indices [start_idx : start_idx + batch_size ])
475
475
start_idx += batch_size
476
476
477
- def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> RolloutBufferSamples :
477
+ def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> RolloutBufferSamples : # type: ignore[signature-mismatch] #FIXME
478
478
data = (
479
479
self .observations [batch_inds ],
480
480
self .actions [batch_inds ],
@@ -603,7 +603,7 @@ def add(
603
603
self .full = True
604
604
self .pos = 0
605
605
606
- def sample (self , batch_size : int , env : Optional [VecNormalize ] = None ) -> DictReplayBufferSamples :
606
+ def sample (self , batch_size : int , env : Optional [VecNormalize ] = None ) -> DictReplayBufferSamples : # type: ignore[signature-mismatch] #FIXME:
607
607
"""
608
608
Sample elements from the replay buffer.
609
609
@@ -614,7 +614,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictRep
614
614
"""
615
615
return super (ReplayBuffer , self ).sample (batch_size = batch_size , env = env )
616
616
617
- def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> DictReplayBufferSamples :
617
+ def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> DictReplayBufferSamples : # type: ignore[signature-mismatch] #FIXME:
618
618
# Sample randomly the env idx
619
619
env_indices = np .random .randint (0 , high = self .n_envs , size = (len (batch_inds ),))
620
620
@@ -743,7 +743,7 @@ def add(
743
743
if self .pos == self .buffer_size :
744
744
self .full = True
745
745
746
- def get (self , batch_size : Optional [int ] = None ) -> Generator [DictRolloutBufferSamples , None , None ]:
746
+ def get (self , batch_size : Optional [int ] = None ) -> Generator [DictRolloutBufferSamples , None , None ]: # type: ignore[signature-mismatch] #FIXME
747
747
assert self .full , ""
748
748
indices = np .random .permutation (self .buffer_size * self .n_envs )
749
749
# Prepare the data
@@ -767,7 +767,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSa
767
767
yield self ._get_samples (indices [start_idx : start_idx + batch_size ])
768
768
start_idx += batch_size
769
769
770
- def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> DictRolloutBufferSamples :
770
+ def _get_samples (self , batch_inds : np .ndarray , env : Optional [VecNormalize ] = None ) -> DictRolloutBufferSamples : # type: ignore[signature-mismatch] #FIXME
771
771
772
772
return DictRolloutBufferSamples (
773
773
observations = {key : self .to_torch (obs [batch_inds ]) for (key , obs ) in self .observations .items ()},
0 commit comments