diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 976f57dd5b9..8253fda80ed 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -19,6 +19,7 @@ # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell from torchrl._utils import timeit +from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP @@ -261,6 +262,8 @@ def forward(self, tensordict): tensordict_out.append(_tensordict) if t < time_steps - 1: + # Translate ("next", *) to the non-next key required for the current step input + _tensordict = step_mdp(_tensordict, keep_other=True) _tensordict = _tensordict.select(*self.in_keys, strict=False) _tensordict = update_values[t + 1].update(_tensordict)