import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, Optional, Tuple, Type
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
+from torchrl.envs import Compose, Transform
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -132,15 +133,15 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater]
        return self._losses_and_updaters[group]

    def get_replay_buffer(
-        self,
-        group: str,
+        self, group: str, transforms: Optional[List[Transform]] = None
    ) -> ReplayBuffer:
        """
        Get the ReplayBuffer for a specific group.
        This function will check ``self.on_policy`` and create the buffer accordingly.

        Args:
            group (str): agent group of the loss and updater
+            transforms (list of Transform, optional): transforms to apply to the replay buffer ``.sample()`` call

        Returns: ReplayBuffer for the group
        """
@@ -154,6 +155,7 @@ def get_replay_buffer(
            sampler=sampler,
            batch_size=sampling_size,
            priority_key=(group, "td_error"),
+            transform=Compose(*transforms) if transforms is not None else None,
        )

    def get_policy_for_loss(self, group: str) -> TensorDictModule:
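
With this change, callers can attach sampling-time transforms to a group's buffer. A minimal usage sketch follows; the `experiment` handle and the `"agents"` group name are illustrative assumptions, not part of this diff, while `DoubleToFloat` is a stock torchrl transform:

    from torchrl.envs import DoubleToFloat

    # Hypothetical caller: "experiment" and "agents" are assumed names.
    # With no keys given, DoubleToFloat casts float64 entries to float32.
    buffer = experiment.get_replay_buffer(
        group="agents",
        transforms=[DoubleToFloat()],
    )

    # Transforms run lazily: Compose(*transforms) is applied when the
    # buffer is sampled, not when data is written into it.
    batch = buffer.sample()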