@@ -240,7 +240,6 @@ def add(
240
240
done : np .ndarray ,
241
241
infos : List [Dict [str , Any ]],
242
242
) -> None :
243
-
244
243
# Reshape needed when using multiple envs with discrete observations
245
244
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
246
245
if isinstance (self .observation_space , spaces .Discrete ):
@@ -346,7 +345,6 @@ def __init__(
346
345
gamma : float = 0.99 ,
347
346
n_envs : int = 1 ,
348
347
):
349
-
350
348
super ().__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
351
349
self .gae_lambda = gae_lambda
352
350
self .gamma = gamma
@@ -356,7 +354,6 @@ def __init__(
356
354
self .reset ()
357
355
358
356
def reset (self ) -> None :
359
-
360
357
self .observations = np .zeros ((self .buffer_size , self .n_envs ) + self .obs_shape , dtype = np .float32 )
361
358
self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = np .float32 )
362
359
self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -451,7 +448,6 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
451
448
indices = np .random .permutation (self .buffer_size * self .n_envs )
452
449
# Prepare the data
453
450
if not self .generator_ready :
454
-
455
451
_tensor_names = [
456
452
"observations" ,
457
453
"actions" ,
@@ -688,7 +684,6 @@ def __init__(
688
684
gamma : float = 0.99 ,
689
685
n_envs : int = 1 ,
690
686
):
691
-
692
687
super (RolloutBuffer , self ).__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
693
688
694
689
assert isinstance (self .obs_shape , dict ), "DictRolloutBuffer must be used with Dict obs space only"
@@ -763,7 +758,6 @@ def get(
763
758
indices = np .random .permutation (self .buffer_size * self .n_envs )
764
759
# Prepare the data
765
760
if not self .generator_ready :
766
-
767
761
for key , obs in self .observations .items ():
768
762
self .observations [key ] = self .swap_and_flatten (obs )
769
763
@@ -787,7 +781,6 @@ def _get_samples(
787
781
batch_inds : np .ndarray ,
788
782
env : Optional [VecNormalize ] = None ,
789
783
) -> DictRolloutBufferSamples : # type: ignore[signature-mismatch] #FIXME
790
-
791
784
return DictRolloutBufferSamples (
792
785
observations = {key : self .to_torch (obs [batch_inds ]) for (key , obs ) in self .observations .items ()},
793
786
actions = self .to_torch (self .actions [batch_inds ]),
0 commit comments