@@ -2273,7 +2273,9 @@ def rollout(
22732273 called on the sub-envs that are done. Default is True.
22742274 return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
22752275 tensordict (TensorDict, optional): if auto_reset is False, an initial
2276- tensordict must be provided.
2276+ tensordict must be provided. Rollout will check whether this tensordict has done flags and, if
2277+ needed, reset the environment in those dimensions. Done flags are normally not set when
2278+ ``tensordict`` is the output of a reset, but they can be set when it is the last step of a previous rollout.
22772279
22782280 Returns:
22792281 TensorDict object containing the resulting trajectory.
@@ -2369,6 +2371,26 @@ def rollout(
23692371 >>> print(rollout.names)
23702372 [None, 'time']
23712373
2374+ Rollouts can be used in a loop to emulate data collection.
2375+ To do so, pass the last tensordict of the previous rollout as input, after calling
2376+ :func:`~torchrl.envs.utils.step_mdp` on it.
2377+
2378+ Examples:
2379+ >>> from torchrl.envs import GymEnv, step_mdp
2380+ >>> env = GymEnv("CartPole-v1")
2381+ >>> epochs = 10
2382+ >>> input_td = env.reset()
2383+ >>> for i in range(epochs):
2384+ ... rollout_td = env.rollout(
2385+ ... max_steps=100,
2386+ ... break_when_any_done=False,
2387+ ... auto_reset=False,
2388+ ... tensordict=input_td,
2389+ ... )
2390+ ... input_td = step_mdp(
2391+ ... rollout_td[..., -1],
2392+ ... )
2393+
23722394 """
23732395 if auto_cast_to_device :
23742396 try :
@@ -2388,6 +2410,9 @@ def rollout(
23882410 tensordict = self .reset ()
23892411 elif tensordict is None :
23902412 raise RuntimeError ("tensordict must be provided when auto_reset is False" )
2413+ else :
2414+ tensordict = self .maybe_reset (tensordict )
2415+
23912416 if policy is None :
23922417
23932418 policy = self .rand_action
@@ -2493,7 +2518,10 @@ def _rollout_nonstop(
24932518 tensordict_ = tensordict_ .to (env_device , non_blocking = True )
24942519 else :
24952520 tensordict_ .clear_device_ ()
2496- tensordict , tensordict_ = self .step_and_maybe_reset (tensordict_ )
2521+ if i == max_steps - 1 :
2522+ tensordict = self .step (tensordict_ )
2523+ else :
2524+ tensordict , tensordict_ = self .step_and_maybe_reset (tensordict_ )
24972525 tensordicts .append (tensordict )
24982526 if i == max_steps - 1 :
24992527 # we don't truncated as one could potentially continue the run
@@ -2557,14 +2585,28 @@ def step_and_maybe_reset(
25572585 action_keys = self .action_keys ,
25582586 done_keys = self .done_keys ,
25592587 )
2588+ tensordict_ = self .maybe_reset (tensordict_ )
2589+ return tensordict , tensordict_
2590+
2591+ def maybe_reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
2592+ """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.
2593+
2594+ Args:
2595+ tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.
2596+
2597+ Returns:
2598+ A tensordict that is identical to the input where the environment was
2599+ not reset and contains the new reset data where the environment was reset.
2600+
2601+ """
25602602 any_done = _terminated_or_truncated (
2561- tensordict_ ,
2603+ tensordict ,
25622604 full_done_spec = self .output_spec ["full_done_spec" ],
25632605 key = "_reset" ,
25642606 )
25652607 if any_done :
2566- tensordict_ = self .reset (tensordict_ )
2567- return tensordict , tensordict_
2608+ tensordict = self .reset (tensordict )
2609+ return tensordict
25682610
25692611 def empty_cache (self ):
25702612 """Erases all the cached values.
0 commit comments