@@ -58,6 +58,23 @@ With these, the following methods are implemented:
5858- :meth: `env.step `: a step method that takes a :class: `tensordict.TensorDict ` input
5959 containing an input action as well as other inputs (for model-based or stateless
6060 environments, for instance).
61+ - :meth: `env.step_and_maybe_reset `: executes a step, and (partially) resets the
62+ environments if it needs to. It returns the updated input with a ``"next" ``
63+ key containing the data of the next step, as well as a tensordict containing
64+ the input data for the next step (ie, reset or result or
65+ :func: `~torchrl.envs.utils.step_mdp `)
66+ This is done by reading the ``done_keys `` and
67+ assigning a ``"_reset" `` signal to each done state. This method allows
68+ to code non-stopping rollout functions with little effort:
69+
70+ >>> data_ = env.reset()
71+ >>> result = []
72+ >>> for i in range (N):
73+ ... data, data_ = env.step_and_maybe_reset(data_)
74+ ... result.append(data)
75+ ...
76+ >>> result = torch.stack(result)
77+
6178- :meth: `env.set_seed `: a seeding method that will return the next seed
6279 to be used in a multi-env setting. This next seed is deterministically computed
6380 from the preceding one, such that one can seed multiple environments with a different
@@ -169,7 +186,95 @@ one can simply call:
169186 >>> print(a)
170187 9.81
171188
172- It is also possible to reset some but not all of the environments:
189+ TorchRL uses a private ``"_reset" `` key to indicate to the environment which
190+ component (sub-environments or agents) should be reset.
191+ This allows to reset some but not all of the components.
192+
193+ The ``"_reset" `` key has two distinct functionalities:
194+ 1. During a call to :meth: `~.EnvBase._reset `, the ``"_reset" `` key may or may
195+ not be present in the input tensordict. TorchRL's convention is that the
196+ absence of the ``"_reset" `` key at a given ``"done" `` level indicates
197+ a total reset of that level (unless a ``"_reset" `` key was found at a level
198+ above, see details below).
199+ If it is present, it is expected that those entries and only those components
200+ where the ``"_reset" `` entry is ``True `` (along key and shape dimension) will be reset.
201+
202+ The way an environment deals with the ``"_reset" `` keys in its :meth: `~.EnvBase._reset `
203+ method is proper to its class.
204+ Designing an environment that behaves according to ``"_reset" `` inputs is the
205+ developer's responsibility, as TorchRL has no control over the inner logic
206+ of :meth: `~.EnvBase._reset `. Nevertheless, the following point should be
207+ kept in mind when desiging that method.
208+
209+ 2. After a call to :meth: `~.EnvBase._reset `, the output will be masked with the
210+ ``"_reset" `` entries and the output of the previous :meth: `~.EnvBase.step `
211+ will be written wherever the ``"_reset" `` was ``False ``. In practice, this
212+ means that if a ``"_reset" `` modifies data that isn't exposed by it, this
213+ modification will be lost. After this masking operation, the ``"_reset" ``
214+ entries will be erased from the :meth: `~.EnvBase.reset ` outputs.
215+
216+ It must be pointed that ``"_reset" `` is a private key, and it should only be
217+ used when coding specific environment features that are internal facing.
218+ In other words, this should NOT be used outside of the library, and developers
219+ will keep the right to modify the logic of partial resets through ``"_reset" ``
220+ setting without preliminary warranty, as long as they don't affect TorchRL
221+ internal tests.
222+
223+ Finally, the following assumptions are made and should be kept in mind when
224+ designing reset functionalities:
225+
226+ - Each ``"_reset" `` is paired with a ``"done" `` entry (+ ``"terminated" `` and,
227+ possibly, ``"truncated" ``). This means that the following structure is not
228+ allowed: ``TensorDict({"done": done, "nested": {"_reset": reset}}, []) ``, as
229+ the ``"_reset" `` lives at a different nesting level than the ``"done" ``.
230+ - A reset at one level does not preclude the presence of a ``"_reset" `` at lower
231+ levels, but it annihilates its effects. The reason is simply that
232+ whether the ``"_reset" `` at the root level corresponds to an ``all() ``, ``any() ``
233+ or custom call to the nested ``"done" `` entries cannot be known in advance,
234+ and it is explicitly assumed that the ``"_reset" `` at the root was placed
235+ there to superseed the nested values (for an example, have a look at
236+ :class: `~.PettingZooWrapper ` implementation where each group has one or more
237+ ``"done" `` entries associated which is aggregated at the root level with a
238+ ``any `` or ``all `` logic depending on the task).
239+ - When calling :meth: `env.reset(tensordict) ` with a partial ``"_reset" `` entry
240+ that will reset some but not all the done sub-environments, the input data
241+ should contain the data of the sub-environemtns that are __not__ being reset.
242+ The reason for this constrain lies in the fact that the output of the
243+ ``env._reset(data) `` can only be predicted for the entries that are reset.
244+ For the others, TorchRL cannot know in advance if they will be meaningful or
245+ not. For instance, one could perfectly just pad the values of the non-reset
246+ components, in which case the non-reset data will be meaningless and should
247+ be discarded.
248+
249+ Below, we give some examples of the expected effect that ``"_reset" `` keys will
250+ have on an environment returning zeros after reset:
251+
252+ >>> # single reset at the root
253+ >>> data = TensorDict({" val" : [1 , 1 ], " _reset" : [False , True ]}, [])
254+ >>> env.reset(data)
255+ >>> print (data.get(" val" )) # only the second value is 0
256+ tensor([1, 0])
257+ >>> # nested resets
258+ >>> data = TensorDict({
259+ ... (" agent0" , " val" ): [1 , 1 ], (" agent0" , " _reset" ): [False , True ],
260+ ... (" agent1" , " val" ): [2 , 2 ], (" agent1" , " _reset" ): [True , False ],
261+ ... }, [])
262+ >>> env.reset(data)
263+ >>> print (data.get((" agent0" , " val" ))) # only the second value is 0
264+ tensor([1, 0])
265+ >>> print (data.get((" agent1" , " val" ))) # only the second value is 0
266+ tensor([0, 2])
267+ >>> # nested resets are overridden by a "_reset" at the root
268+ >>> data = TensorDict({
269+ ... " _reset" : [True , True ],
270+ ... (" agent0" , " val" ): [1 , 1 ], (" agent0" , " _reset" ): [False , True ],
271+ ... (" agent1" , " val" ): [2 , 2 ], (" agent1" , " _reset" ): [True , False ],
272+ ... }, [])
273+ >>> env.reset(data)
274+ >>> print (data.get((" agent0" , " val" ))) # reset at the root overrides nested
275+ tensor([0, 0])
276+ >>> print (data.get((" agent1" , " val" ))) # reset at the root overrides nested
277+ tensor([0, 0])
173278
174279.. code-block ::
175280 :caption: Parallel environment reset
0 commit comments