@@ -381,6 +381,8 @@ class LSTMModule(ModuleBase):
     Methods:
         set_recurrent_mode: controls whether the module should be executed in
             recurrent mode.
+        make_tensordict_primer: creates a TensorDictPrimer transform that makes the
+            environment aware of the recurrent states of the RNN.
 
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
         TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -521,6 +523,45 @@ def __init__(
         self._recurrent_mode = False
 
     def make_tensordict_primer(self):
+        """Makes a tensordict primer for the environment.
+
+        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
+        inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
+        processes and dealt with properly.
+
+        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
+        tensordict, which the :func:`~torchrl.envs.utils.step_mdp` function will not be able to do as the recurrent
+        states are not registered within the environment specs.
+
+        Examples:
+            >>> from torchrl.collectors import SyncDataCollector
+            >>> from torchrl.envs import TransformedEnv, InitTracker
+            >>> from torchrl.envs import GymEnv
+            >>> from torchrl.modules import MLP, LSTMModule
+            >>> from torch import nn
+            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+            >>>
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+            >>> lstm_module = LSTMModule(
+            ...     input_size=env.observation_spec["observation"].shape[-1],
+            ...     hidden_size=64,
+            ...     in_keys=["observation", "rs_h", "rs_c"],
+            ...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
+            >>> mlp = MLP(num_cells=[64], out_features=1)
+            >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
+            >>> policy(env.reset())
+            >>> env = env.append_transform(lstm_module.make_tensordict_primer())
+            >>> data_collector = SyncDataCollector(
+            ...     env,
+            ...     policy,
+            ...     frames_per_batch=10
+            ... )
+            >>> for data in data_collector:
+            ...     print(data)
+            ...     break
+
+        """
         from torchrl.envs.transforms.transforms import TensorDictPrimer
 
         def make_tuple(key):
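
For context, a minimal sketch (not the library implementation) of what the transform returned by `make_tensordict_primer` amounts to for the LSTM example above, assuming a single-layer LSTM with `hidden_size=64` and the example's `"rs_h"`/`"rs_c"` key names:

```python
# A minimal sketch, not the library code: roughly what make_tensordict_primer
# produces for the LSTM example above, assuming a single-layer LSTM with
# hidden_size=64 and the example's key names.
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import TensorDictPrimer

num_layers, hidden_size = 1, 64  # matches the docstring example
primer = TensorDictPrimer(
    {
        # one spec per recurrent state, shaped (num_layers, hidden_size)
        "rs_h": UnboundedContinuousTensorSpec(shape=(num_layers, hidden_size)),
        "rs_c": UnboundedContinuousTensorSpec(shape=(num_layers, hidden_size)),
    }
)
# Appending this transform registers the recurrent states in the environment
# specs, so step_mdp can carry them from "next" back to the root tensordict.
```
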
@@ -1065,6 +1106,8 @@ class GRUModule(ModuleBase):
     Methods:
         set_recurrent_mode: controls whether the module should be executed in
             recurrent mode.
+        make_tensordict_primer: creates a TensorDictPrimer transform that makes the
+            environment aware of the recurrent states of the RNN.
 
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
         TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -1230,6 +1273,45 @@ def __init__(
         self._recurrent_mode = False
 
     def make_tensordict_primer(self):
+        """Makes a tensordict primer for the environment.
+
+        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
+        inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
+        processes and dealt with properly.
+
+        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
+        tensordict, which the :func:`~torchrl.envs.utils.step_mdp` function will not be able to do as the recurrent
+        states are not registered within the environment specs.
+
+        Examples:
+            >>> from torchrl.collectors import SyncDataCollector
+            >>> from torchrl.envs import TransformedEnv, InitTracker
+            >>> from torchrl.envs import GymEnv
+            >>> from torchrl.modules import MLP, GRUModule
+            >>> from torch import nn
+            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+            >>>
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+            >>> gru_module = GRUModule(
+            ...     input_size=env.observation_spec["observation"].shape[-1],
+            ...     hidden_size=64,
+            ...     in_keys=["observation", "rs"],
+            ...     out_keys=["intermediate", ("next", "rs")])
+            >>> mlp = MLP(num_cells=[64], out_features=1)
+            >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
+            >>> policy(env.reset())
+            >>> env = env.append_transform(gru_module.make_tensordict_primer())
+            >>> data_collector = SyncDataCollector(
+            ...     env,
+            ...     policy,
+            ...     frames_per_batch=10
+            ... )
+            >>> for data in data_collector:
+            ...     print(data)
+            ...     break
+
+        """
         from torchrl.envs import TensorDictPrimer
 
         def make_tuple(key):
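
The GRU variant can be checked the same way; a short, hedged sketch (assuming gymnasium's Pendulum-v1 is installed) showing that the primed recurrent state appears in the reset tensordict once the transform is appended:

```python
# A hedged sanity check, assuming Pendulum-v1 is available: once the primer is
# appended, the recurrent state is zero-initialized in the reset tensordict.
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import GRUModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
gru_module = GRUModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs"],
    out_keys=["intermediate", ("next", "rs")],
)
env = env.append_transform(gru_module.make_tensordict_primer())
td = env.reset()
assert "rs" in td.keys()  # registered in the specs and primed at reset
```
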