@@ -640,7 +640,7 @@ class ReplayBufferTrainer(TrainerHookBase):
         memmap (bool, optional): if ``True``, a memmap tensordict is created.
             Default is ``False``.
         device (device, optional): device where the samples must be placed.
-            Default is ``cpu``.
+            Defaults to ``None``.
         flatten_tensordicts (bool, optional): if ``True``, the tensordicts will be
             flattened (or equivalently masked with the valid mask obtained from
             the collector) before being passed to the replay buffer. Otherwise,
@@ -666,7 +666,7 @@ def __init__(
         replay_buffer: TensorDictReplayBuffer,
         batch_size: Optional[int] = None,
         memmap: bool = False,
-        device: DEVICE_TYPING = "cpu",
+        device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
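With this change, ``device`` is genuinely optional: left as ``None``, the hook no longer forces samples onto any particular device. A minimal construction sketch, assuming a standard torchrl setup (the ``TensorDictReplayBuffer``/``LazyTensorStorage`` wiring below is illustrative, not part of this diff):

    import torch
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.trainers import ReplayBufferTrainer

    buffer = TensorDictReplayBuffer(storage=LazyTensorStorage(10_000))

    # New default: device=None, sampled tensordicts stay on the storage device.
    hook = ReplayBufferTrainer(replay_buffer=buffer, batch_size=256)

    # Passing a device explicitly restores the old "always cast" behavior.
    hook_cuda = ReplayBufferTrainer(
        replay_buffer=buffer,
        batch_size=256,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )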
@@ -695,15 +695,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase:
                    pads += [0, pad_value]
                batch = pad(batch, pads)
        batch = batch.cpu()
-        if self.memmap:
-            # We can already place the tensords on the device if they're memmap,
-            # as this is a lazy op
-            batch = batch.memmap_().to(self.device)
        self.replay_buffer.extend(batch)

    def sample(self, batch: TensorDictBase) -> TensorDictBase:
        sample = self.replay_buffer.sample(batch_size=self.batch_size)
-        return sample.to(self.device, non_blocking=True)
+        return sample.to(self.device) if self.device is not None else sample

    def update_priority(self, batch: TensorDictBase) -> None:
        self.replay_buffer.update_tensordict_priority(batch)
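Net behavior change in both hooks: ``extend`` no longer eagerly memmaps the batch and casts it, and ``sample`` only moves data when a device was actually requested (note the old path also passed ``non_blocking=True``, which the new explicit ``.to(self.device)`` drops). A short usage sketch reusing the hypothetical ``hook``/``hook_cuda`` objects from above, with ``batch`` standing in for any ``TensorDictBase``:

    td = hook.sample(batch)       # device is None: returned as-is, no .to() call
    td = hook_cuda.sample(batch)  # device set: equivalent to sample.to("cuda:0")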