Skip to content

Commit 5e03a4c

Browse files
matteobettini and vmoens
committed
[BugFix] Remove reset on last step of a rollout (#1936)
Co-authored-by: vmoens <[email protected]>
1 parent ff2e265 commit 5e03a4c

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

test/test_env.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,35 @@ def test_rollout(env_name, frame_skip, seed=0):
217217
env.close()
218218

219219

220+
@pytest.mark.parametrize("max_steps", [1, 5])
221+
def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
222+
# CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
223+
env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
224+
policy = CountingEnvCountPolicy(
225+
action_spec=env.action_spec, action_key=env.action_key
226+
)
227+
228+
input_td = env.reset()
229+
for _ in range(epochs):
230+
rollout_td = env.rollout(
231+
max_steps=max_steps,
232+
policy=policy,
233+
auto_reset=False,
234+
break_when_any_done=False,
235+
tensordict=input_td,
236+
)
237+
assert (env.count == max_steps).all()
238+
input_td = step_mdp(
239+
rollout_td[..., -1],
240+
keep_other=True,
241+
exclude_action=False,
242+
exclude_reward=True,
243+
reward_keys=env.reward_keys,
244+
action_keys=env.action_keys,
245+
done_keys=env.done_keys,
246+
)
247+
248+
220249
@pytest.mark.parametrize("device", get_default_devices())
221250
def test_rollout_predictability(device):
222251
env = MockSerialEnv(device=device)

torchrl/envs/common.py

Lines changed: 47 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2273,7 +2273,9 @@ def rollout(
22732273
called on the sub-envs that are done. Default is True.
22742274
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
22752275
tensordict (TensorDict, optional): if auto_reset is False, an initial
2276-
tensordict must be provided.
2276+
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
2277+
environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the
2278+
output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout.
22772279
22782280
Returns:
22792281
TensorDict object containing the resulting trajectory.
@@ -2369,6 +2371,26 @@ def rollout(
23692371
>>> print(rollout.names)
23702372
[None, 'time']
23712373
2374+
Rollouts can be used in a loop to emulate data collection.
2375+
To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
2376+
:func:`~torchrl.envs.utils.step_mdp` on it.
2377+
2378+
Examples:
2379+
>>> from torchrl.envs import GymEnv, step_mdp
2380+
>>> env = GymEnv("CartPole-v1")
2381+
>>> epochs = 10
2382+
>>> input_td = env.reset()
2383+
>>> for i in range(epochs):
2384+
... rollout_td = env.rollout(
2385+
... max_steps=100,
2386+
... break_when_any_done=False,
2387+
... auto_reset=False,
2388+
... tensordict=input_td,
2389+
... )
2390+
... input_td = step_mdp(
2391+
... rollout_td[..., -1],
2392+
... )
2393+
23722394
"""
23732395
if auto_cast_to_device:
23742396
try:
@@ -2388,6 +2410,9 @@ def rollout(
23882410
tensordict = self.reset()
23892411
elif tensordict is None:
23902412
raise RuntimeError("tensordict must be provided when auto_reset is False")
2413+
else:
2414+
tensordict = self.maybe_reset(tensordict)
2415+
23912416
if policy is None:
23922417

23932418
policy = self.rand_action
@@ -2493,7 +2518,10 @@ def _rollout_nonstop(
24932518
tensordict_ = tensordict_.to(env_device, non_blocking=True)
24942519
else:
24952520
tensordict_.clear_device_()
2496-
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
2521+
if i == max_steps - 1:
2522+
tensordict = self.step(tensordict_)
2523+
else:
2524+
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
24972525
tensordicts.append(tensordict)
24982526
if i == max_steps - 1:
24992527
# we don't truncated as one could potentially continue the run
@@ -2557,14 +2585,28 @@ def step_and_maybe_reset(
25572585
action_keys=self.action_keys,
25582586
done_keys=self.done_keys,
25592587
)
2588+
tensordict_ = self.maybe_reset(tensordict_)
2589+
return tensordict, tensordict_
2590+
2591+
def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
2592+
"""Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.
2593+
2594+
Args:
2595+
tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.
2596+
2597+
Returns:
2598+
A tensordict that is identical to the input where the environment was
2599+
not reset and contains the new reset data where the environment was reset.
2600+
2601+
"""
25602602
any_done = _terminated_or_truncated(
2561-
tensordict_,
2603+
tensordict,
25622604
full_done_spec=self.output_spec["full_done_spec"],
25632605
key="_reset",
25642606
)
25652607
if any_done:
2566-
tensordict_ = self.reset(tensordict_)
2567-
return tensordict, tensordict_
2608+
tensordict = self.reset(tensordict)
2609+
return tensordict
25682610

25692611
def empty_cache(self):
25702612
"""Erases all the cached values.

0 commit comments

Comments
 (0)