Skip to content

Commit 0063741

Browse files
author
Vincent Moens
authored
[Doc] Better doc for make_tensordict_primer (#2324)
1 parent 935e8da commit 0063741

File tree

1 file changed

+82
-0
lines changed
  • torchrl/modules/tensordict_module

1 file changed

+82
-0
lines changed

torchrl/modules/tensordict_module/rnn.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ class LSTMModule(ModuleBase):
381381
Methods:
382382
set_recurrent_mode: controls whether the module should be executed in
383383
recurrent mode.
384+
make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the
385+
recurrent states of the RNN.
384386
385387
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
386388
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -521,6 +523,45 @@ def __init__(
521523
self._recurrent_mode = False
522524

523525
def make_tensordict_primer(self):
526+
"""Makes a tensordict primer for the environment.
527+
528+
A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
529+
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
530+
processes and dealt with properly.
531+
532+
Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
533+
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
534+
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
535+
are not registered within the environment specs.
536+
537+
Examples:
538+
>>> from torchrl.collectors import SyncDataCollector
539+
>>> from torchrl.envs import TransformedEnv, InitTracker
540+
>>> from torchrl.envs import GymEnv
541+
>>> from torchrl.modules import MLP, LSTMModule
542+
>>> from torch import nn
543+
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
544+
>>>
545+
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
546+
>>> lstm_module = LSTMModule(
547+
... input_size=env.observation_spec["observation"].shape[-1],
548+
... hidden_size=64,
549+
... in_keys=["observation", "rs_h", "rs_c"],
550+
... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
551+
>>> mlp = MLP(num_cells=[64], out_features=1)
552+
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
553+
>>> policy(env.reset())
554+
>>> env = env.append_transform(lstm_module.make_tensordict_primer())
555+
>>> data_collector = SyncDataCollector(
556+
... env,
557+
... policy,
558+
... frames_per_batch=10
559+
... )
560+
>>> for data in data_collector:
561+
... print(data)
562+
... break
563+
564+
"""
524565
from torchrl.envs.transforms.transforms import TensorDictPrimer
525566

526567
def make_tuple(key):
@@ -1065,6 +1106,8 @@ class GRUModule(ModuleBase):
10651106
Methods:
10661107
set_recurrent_mode: controls whether the module should be executed in
10671108
recurrent mode.
1109+
make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the
1110+
recurrent states of the RNN.
10681111
10691112
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
10701113
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -1230,6 +1273,45 @@ def __init__(
12301273
self._recurrent_mode = False
12311274

12321275
def make_tensordict_primer(self):
1276+
"""Makes a tensordict primer for the environment.
1277+
1278+
A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
1279+
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
1280+
processes and dealt with properly.
1281+
1282+
Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
1283+
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
1284+
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
1285+
are not registered within the environment specs.
1286+
1287+
Examples:
1288+
>>> from torchrl.collectors import SyncDataCollector
1289+
>>> from torchrl.envs import TransformedEnv, InitTracker
1290+
>>> from torchrl.envs import GymEnv
1291+
>>> from torchrl.modules import MLP, LSTMModule
1292+
>>> from torch import nn
1293+
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
1294+
>>>
1295+
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
1296+
>>> gru_module = GRUModule(
1297+
... input_size=env.observation_spec["observation"].shape[-1],
1298+
... hidden_size=64,
1299+
... in_keys=["observation", "rs"],
1300+
... out_keys=["intermediate", ("next", "rs")])
1301+
>>> mlp = MLP(num_cells=[64], out_features=1)
1302+
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
1303+
>>> policy(env.reset())
1304+
>>> env = env.append_transform(gru_module.make_tensordict_primer())
1305+
>>> data_collector = SyncDataCollector(
1306+
... env,
1307+
... policy,
1308+
... frames_per_batch=10
1309+
... )
1310+
>>> for data in data_collector:
1311+
... print(data)
1312+
... break
1313+
1314+
"""
12331315
from torchrl.envs import TensorDictPrimer
12341316

12351317
def make_tuple(key):

0 commit comments

Comments
 (0)