Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchrl/modules/models/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# from torchrl.modules.tensordict_module.rnn import GRUCell
from torch.nn import GRUCell
from torchrl._utils import timeit
from torchrl.envs.utils import step_mdp

from torchrl.modules.models.models import MLP

Expand Down Expand Up @@ -261,6 +262,8 @@ def forward(self, tensordict):

tensordict_out.append(_tensordict)
if t < time_steps - 1:
# Translate ("next", *) to the non-next key required for the current step input
_tensordict = step_mdp(_tensordict, keep_other=True)
_tensordict = _tensordict.select(*self.in_keys, strict=False)
_tensordict = update_values[t + 1].update(_tensordict)

Expand Down