
Commit 38544a5

[BE] Catching common errors in env.rollout and rb.add (#3102)
1 parent 1eccb49 commit 38544a5

File tree

6 files changed: 139 additions & 25 deletions


test/mocking_classes.py

Lines changed: 66 additions & 0 deletions
@@ -2588,3 +2588,69 @@ def _step(self, tensordict: TensorDict) -> TensorDict:
 
     def _set_seed(self):
         pass
+
+
+class EnvThatErrorsBecauseOfStack(EnvBase):
+    def __init__(self, target: int = 5, batch_size: int | None = None):
+        super().__init__(device="cpu", batch_size=batch_size)
+        self.target = target
+        self.observation_spec = Bounded(
+            low=0, high=self.target, shape=(1,), dtype=torch.int64
+        )
+        self.action_spec = Categorical(n=2, shape=(1,), dtype=torch.int64)
+        self.reward_spec = Unbounded(shape=(1,), dtype=torch.float32)
+        self.done_spec = Categorical(n=2, shape=(1,), dtype=torch.bool)
+
+    def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
+        if tensordict is None:
+            tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
+
+        observation = torch.zeros(
+            self.batch_size, dtype=self.observation_spec.dtype, device=self.device
+        )
+        reward = torch.zeros(
+            self.batch_size + torch.Size([1]),
+            dtype=self.reward_spec.dtype,
+            device=self.device,
+        )
+        done = torch.zeros(
+            self.batch_size + torch.Size([1]), dtype=torch.bool, device=self.device
+        )
+        terminated = torch.zeros_like(done)
+        action = torch.zeros(
+            self.batch_size + torch.Size([1]), dtype=torch.int64, device=self.device
+        )
+
+        tensordict.set(self.observation_keys[0], observation)
+        tensordict.set(self.reward_key, reward)
+        tensordict.set(self.done_keys[0], done)
+        tensordict.set("terminated", terminated)
+        tensordict.set(self.action_keys[0], action)
+
+        return tensordict
+
+    def _step(self, tensordict: TensorDict) -> TensorDict:
+        obs = tensordict.get(
+            self.observation_keys[0]
+        )  # the counter value, or the counter values if the env is batched
+        action = tensordict.get(self.action_keys[0]).squeeze(-1)
+
+        new_obs = obs + (action == 1).to(obs.dtype)
+        new_obs = new_obs.clamp_max(self.target)
+        reward = (new_obs == self.target).to(self.reward_spec.dtype).unsqueeze(-1)
+        done = (new_obs == self.target).to(torch.bool).unsqueeze(-1)
+        terminated = done.clone()
+        return TensorDict(
+            {
+                self.observation_keys[0]: new_obs,
+                self.reward_keys[0]: reward,
+                self.done_keys[0]: done,
+                "terminated": terminated,
+                self.action_keys[0]: action.unsqueeze(-1),
+            },
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+
+    def _set_seed(self, seed: int | None) -> None:
+        return 0
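
A minimal sketch of how this mock surfaces the new error (assuming the class above is importable from test.mocking_classes); because `_reset` writes the reward into the root tensordict, the dense stack performed by rollout now fails with the descriptive RuntimeError rather than an opaque stacking error:

    from test.mocking_classes import EnvThatErrorsBecauseOfStack

    env = EnvThatErrorsBecauseOfStack()
    try:
        # return_contiguous=True forces a dense stack of the per-step tensordicts
        env.rollout(10, return_contiguous=True)
    except RuntimeError as err:
        print(err)  # "The reward key was present in the root tensordict ..."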

test/test_env.py

Lines changed: 16 additions & 0 deletions
@@ -136,6 +136,7 @@
     DiscreteActionVecMockEnv,
     DummyModelBasedEnvBase,
     EnvThatDoesNothing,
+    EnvThatErrorsBecauseOfStack,
     EnvWithDynamicSpec,
     EnvWithMetadata,
     EnvWithTensorClass,
@@ -178,6 +179,7 @@
     DiscreteActionVecMockEnv,
     DummyModelBasedEnvBase,
     EnvThatDoesNothing,
+    EnvThatErrorsBecauseOfStack,
     EnvWithDynamicSpec,
     EnvWithMetadata,
     EnvWithTensorClass,
@@ -344,6 +346,20 @@ def forward(self, values):
         )
         env.rollout(10, policy)
 
+    def test_stack_error(self):
+        env = EnvThatErrorsBecauseOfStack()
+        assert not env._has_dynamic_specs
+        cm = pytest.raises(
+            RuntimeError,
+            match="The reward key was present in the root tensordict of at least one of the tensordicts to stack",
+        )
+        with cm:
+            env.check_env_specs()
+        with cm:
+            env.rollout(10, break_when_any_done=True, return_contiguous=True)
+        with cm:
+            env.rollout(10, break_when_any_done=False, return_contiguous=True)
+
     @pytest.mark.parametrize("dynamic_shape", [True, False])
     def test_make_spec_from_td(self, dynamic_shape):
         data = TensorDict(
test/test_rb.py

Lines changed: 9 additions & 0 deletions
@@ -1790,6 +1790,15 @@ def test_batch_errors():
     rb.sample()
 
 
+def test_add_warning():
+    rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
+    with pytest.warns(
+        UserWarning,
+        match=r"Using `add\(\)` with a TensorDict that has batch_size",
+    ):
+        rb.add(TensorDict(batch_size=[1]))
+
+
 @pytest.mark.parametrize("priority_key", ["pk", "td_error"])
 @pytest.mark.parametrize("contiguous", [True, False])
 @pytest.mark.parametrize("device", get_default_devices())

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 8 additions & 1 deletion
@@ -39,7 +39,7 @@
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from torchrl._utils import accept_remote_rref_udf_invocation
+from torchrl._utils import accept_remote_rref_udf_invocation, RL_WARNINGS
 from torchrl.data.replay_buffers.samplers import (
     PrioritizedSampler,
     RandomSampler,
@@ -719,6 +719,13 @@ def add(self, data: Any) -> int:
             data = None
         if data is None:
             return torch.zeros((0, self._storage.ndim), dtype=torch.long)
+        if RL_WARNINGS and is_tensor_collection(data) and data.ndim:
+            warnings.warn(
+                f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. "
+                f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). "
+                "You can silence this warning by setting the `RL_WARNINGS` environment variable to `'0'`."
+            )
         return self._add(data)
 
     def _add(self, data):
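
A minimal usage sketch of the pattern the new warning steers toward (assuming the standard ReplayBuffer, LazyTensorStorage and TensorDict APIs):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(100))

    # add() expects a single element, i.e. batch_size=torch.Size([])
    rb.add(TensorDict({"obs": torch.zeros(3)}, batch_size=[]))

    # extend() is the right call for a batch of elements
    rb.extend(TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4]))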

torchrl/envs/common.py

Lines changed: 18 additions & 4 deletions
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import abc
+import re
 import warnings
 import weakref
 from copy import deepcopy
@@ -725,7 +726,6 @@ def auto_specs_(
 
         return self
 
-    @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
         kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)
@@ -2927,7 +2927,7 @@ def _reset_check_done(self, tensordict, tensordict_reset):
         ):
             warnings.warn(
                 f"A partial `'_reset'` key has been passed to `reset` ({reset_key}), "
-                f"but the corresponding done_key ({done_key}) was not present in the input "
+                f"but the corresponding done_key ({done_key}) wasn't present in the input "
                 f"tensordict. "
                 f"This is discouraged, since the input tensordict should contain "
                 f"all the data not being reset."
@@ -3387,12 +3387,26 @@ def rollout(
                 out_td = torch.stack(tensordicts, len(batch_size), out=out)
             except RuntimeError as err:
                 if (
-                    "The shapes of the tensors to stack is incompatible" in str(err)
+                    re.match(
+                        "The shapes of the tensors to stack is incompatible", str(err)
+                    )
                     and self._has_dynamic_specs
                 ):
                     raise RuntimeError(
                         "The environment specs are dynamic. Call rollout with return_contiguous=False."
                     )
+                if re.match(
+                    "The sets of keys in the tensordicts to stack are exclusive",
+                    str(err),
+                ):
+                    for reward_key in self.reward_keys:
+                        if any(reward_key in td for td in tensordicts):
+                            raise RuntimeError(
+                                "The reward key was present in the root tensordict of at least one of the tensordicts to stack. "
+                                "The likely cause is that your environment returns a reward during a call to `reset`, which is not allowed. "
+                                "To fix this, you should return the reward in the `step` method but not during `reset`. If you need a reward "
+                                "to be returned during `reset`, submit an issue on github."
+                            )
                 raise
         else:
             out_td = LazyStackedTensorDict.maybe_dense_stack(
@@ -3967,7 +3981,7 @@ def __getattr__(self, attr: str) -> Any:
             super().__getattr__(attr)
 
         raise AttributeError(
-            f"env not set in {self.__class__.__name__}, cannot access {attr}"
+            f"The env wasn't set in {self.__class__.__name__}, cannot access {attr}"
         )
 
     @abc.abstractmethod
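
A minimal sketch of the fix the new error message asks for, modeled on the mock env above (hypothetical `_reset`; specs and keys as in EnvThatErrorsBecauseOfStack): write observation/done/terminated at reset time and leave the reward to `_step` only:

    def _reset(self, tensordict=None, **kwargs):
        if tensordict is None:
            tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
        observation = torch.zeros(
            self.batch_size, dtype=self.observation_spec.dtype, device=self.device
        )
        done = torch.zeros(
            self.batch_size + torch.Size([1]), dtype=torch.bool, device=self.device
        )
        tensordict.set(self.observation_keys[0], observation)
        tensordict.set(self.done_keys[0], done)
        tensordict.set("terminated", done.clone())
        # no reward key here: rewards belong to `_step`, never to `_reset`
        return tensordict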

torchrl/envs/utils.py

Lines changed: 22 additions & 20 deletions
@@ -191,7 +191,7 @@ def _is_reset(key: NestedKey):
             "extra keys can be present in the input TensorDict). "
             "As a result, step_mdp will need to run extra key checks at each iteration. "
             f"{{Expected keys}}-{{Actual keys}}={set(expected) - actual} (<= this set should be empty), \n"
-            f"{{Actual keys}}-{{Expected keys}}={actual- set(expected)}."
+            f"{{Actual keys}}-{{Expected keys}}={actual - set(expected)}."
         )
         return self.validated
 
@@ -689,7 +689,7 @@ def check_env_specs(
     check_dtype=True,
     seed: int | None = None,
     tensordict: TensorDictBase | None = None,
-    break_when_any_done: bool | Literal["both"] = None,
+    break_when_any_done: bool | Literal["both"] | None = None,
 ):
     """Tests an environment specs against the results of short rollout.
 
@@ -786,10 +786,12 @@ def check_env_specs(
         real_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor)
     )
     if fake_tensordict_keys != real_tensordict_keys:
+        keys_in_real_not_in_fake = real_tensordict_keys - fake_tensordict_keys
+        keys_in_fake_not_in_real = fake_tensordict_keys - real_tensordict_keys
         raise AssertionError(
             f"""The keys of the specs and data do not match:
-    - List of keys present in real but not in fake: {real_tensordict_keys-fake_tensordict_keys},
-    - List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
+    - List of keys present in real but not in fake: {keys_in_real_not_in_fake=},
+    - List of keys present in fake but not in real: {keys_in_fake_not_in_real=}.
    """
         )
 
@@ -1105,14 +1107,14 @@ def check_marl_grouping(group_map: dict[str, list[str]], agent_names: list[str])
             raise ValueError(f"Group {group_name} is empty")
         for agent_name in group:
             if agent_name not in found_agents:
-                raise ValueError(f"Agent {agent_name} not present in environment")
+                raise ValueError(f"Agent {agent_name} wasn't present in environment")
             if not found_agents[agent_name]:
                 found_agents[agent_name] = True
             else:
                 raise ValueError(f"Agent {agent_name} present more than once")
     for agent_name, found in found_agents.items():
         if not found:
-            raise ValueError(f"Agent {agent_name} not found in any group")
+            raise ValueError(f"Agent {agent_name} wasn't found in any group")
 
 
 def _terminated_or_truncated(
@@ -1607,19 +1609,19 @@ def _make_compatible_policy(
     else:
         raise TypeError(
             f"""This error is raised because TorchRL tried to automatically wrap your policy in
-a TensorDictModule. If you're confident the policy can directly process environment outputs, set
-the `trust_policy` argument to `True` in the constructor.
-
-Arguments to policy.forward are incompatible with entries in
-env.observation_spec (got incongruent signatures:
-the function signature is {set(sig.parameters)} but the specs have keys {set(next_observation)}).
-If you want TorchRL to automatically wrap your policy with a TensorDictModule
-then the arguments to policy.forward must correspond one-to-one with entries
-in env.observation_spec.
-For more complex behavior and more control you can consider writing your
-own TensorDictModule.
-Check the collector documentation to know more about accepted policies.
-"""
+            a TensorDictModule. If you're confident the policy can directly process environment outputs, set
+            the `trust_policy` argument to `True` in the constructor.
+
+            Arguments to policy.forward are incompatible with entries in
+            env.observation_spec (got incongruent signatures:
+            the function signature is {set(sig.parameters)} but the specs have keys {set(next_observation)}).
+            If you want TorchRL to automatically wrap your policy with a TensorDictModule
+            then the arguments to policy.forward must correspond one-to-one with entries
+            in env.observation_spec.
+            For more complex behavior and more control you can consider writing your
+            own TensorDictModule.
+            Check the collector documentation to know more about accepted policies.
+            """
         )
     return policy
 
@@ -1736,5 +1738,5 @@ def __getattr__(self, attr: str) -> Any:
         super().__getattr__(attr)
     except Exception:
         raise AttributeError(
-            f"policy not set in {self.__class__.__name__}, cannot access {attr}."
+            f"The policy wasn't set in {self.__class__.__name__}, cannot access {attr}."
        )
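
A short usage sketch of the loosened `check_env_specs` signature (assuming a registered Gym env; GymEnv("CartPole-v1") is only illustrative):

    from torchrl.envs import GymEnv
    from torchrl.envs.utils import check_env_specs

    env = GymEnv("CartPole-v1")
    # break_when_any_done now explicitly accepts True/False, "both", or None
    check_env_specs(env, seed=0, break_when_any_done="both")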
