@@ -9,8 +9,8 @@ The goal is to be able to swap environments in an experiment with little or no e
 even if these environments are simulated using different libraries.
 TorchRL offers some out-of-the-box environment wrappers under :obj:`torchrl.envs.libs`,
 which we hope can be easily imitated for other libraries.
-The parent class :obj:`EnvBase` is a :obj:`torch.nn.Module` subclass that implements
-some typical environment methods using :obj:`TensorDict` as a data organiser. This allows this
+The parent class :class:`torchrl.envs.EnvBase` is a :class:`torch.nn.Module` subclass that implements
+some typical environment methods using :class:`tensordict.TensorDict` as a data organiser. This allows this
 class to be generic and to handle an arbitrary number of input and outputs, as well as
 nested or batched data structures.
 
@@ -25,10 +25,10 @@ Each env will have the following attributes:
   This is especially useful for transforms (see below). For parametric environments (e.g.
   model-based environments), the device does represent the hardware that will be used to
   compute the operations.
-- :obj:`env.observation_spec`: a :obj:`CompositeSpec` object containing all the observation key-spec pairs.
-- :obj:`env.input_spec`: a :obj:`CompositeSpec` object containing all the input keys (:obj:`"action"` and others).
-- :obj:`env.action_spec`: a :obj:`TensorSpec` object representing the action spec.
-- :obj:`env.reward_spec`: a :obj:`TensorSpec` object representing the reward spec.
+- :obj:`env.observation_spec`: a :class:`torchrl.data.CompositeSpec` object containing all the observation key-spec pairs.
+- :obj:`env.input_spec`: a :class:`torchrl.data.CompositeSpec` object containing all the input keys (:obj:`"action"` and others).
+- :obj:`env.action_spec`: a :class:`torchrl.data.TensorSpec` object representing the action spec.
+- :obj:`env.reward_spec`: a :class:`torchrl.data.TensorSpec` object representing the reward spec.
 
 Importantly, the environment spec shapes should *not* contain the batch size, e.g.
 an environment with :obj:`env.batch_size == torch.Size([4])` should not have
@@ -38,9 +38,9 @@ an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])` but sim
 With these, the following methods are implemented:
 
 - :obj:`env.reset(tensordict)`: a reset method that may (but not necessarily requires to) take
-  a :obj:`TensorDict` input. It return the first tensordict of a rollout, usually
+  a :class:`tensordict.TensorDict` input. It returns the first tensordict of a rollout, usually
   containing a :obj:`"done"` state and a set of observations.
-- :obj:`env.step(tensordict)`: a step method that takes a :obj:`TensorDict` input
+- :obj:`env.step(tensordict)`: a step method that takes a :class:`tensordict.TensorDict` input
   containing an input action as well as other inputs (for model-based or stateless
   environments, for instance).
 - :obj:`env.set_seed(integer)`: a seeding method that will return the next seed
@@ -51,7 +51,7 @@ With these, the following methods are implemented:
 - :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for
   a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`.
   The policy should be coded using a :obj:`SafeModule` (or any other
-  :obj:`TensorDict`-compatible module).
+  :class:`tensordict.TensorDict`-compatible module).
 
 
 .. autosummary::
@@ -204,6 +204,47 @@ in the environment. The keys to be included in this inverse transform are passed
 
     >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step
 
+Cloning transforms
+~~~~~~~~~~~~~~~~~~
+
+Because transforms appended to an environment are "registered" to this environment
+through the ``transform.parent`` property, when manipulating transforms we should keep
+in mind that the parent may be set or lost depending on what is done with the transform.
+Here are some examples: if we get a single transform from a :class:`Compose` object,
+this transform will keep its parent:
+
+    >>> third_transform = env.transform[2]
+    >>> assert third_transform.parent is not None
+
+This means that using this transform for another environment is prohibited, as
+the other environment would replace the parent and this may lead to unexpected
+behaviours. Fortunately, the :class:`Transform` class comes with a :func:`clone`
+method that will erase the parent while keeping the identity of all the
+registered buffers:
+
+    >>> TransformedEnv(base_env, third_transform) # raises an Exception as third_transform already has a parent
+    >>> TransformedEnv(base_env, third_transform.clone()) # works
+
+On a single process or if the buffers are placed in shared memory, all the
+cloned transforms will keep the same behaviour even if the
+buffers are changed in place (which is what will happen with the :class:`CatFrames`
+transform, for instance). In distributed settings, this may not hold and one
+should be careful about the expected behaviour of the cloned transforms in this
+context.
+Finally, notice that indexing multiple transforms from a :class:`Compose` transform
+may also result in loss of parenthood for these transforms: the reason is that
+indexing a :class:`Compose` transform results in another :class:`Compose` transform
+that does not have a parent environment. Hence, we have to clone the sub-transforms
+to be able to create this other composition:
+
+    >>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3))
+    >>> last_two = env.transform[-2:]
+    >>> assert isinstance(last_two, Compose)
+    >>> assert last_two.parent is None
+    >>> assert last_two[0] is not transform2
+    >>> assert isinstance(last_two[0], type(transform2)) # and the buffers will match
+    >>> assert last_two[1] is not transform3
+    >>> assert isinstance(last_two[1], type(transform3)) # and the buffers will match
 
 .. autosummary::
     :toctree: generated/
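
The parent/clone behaviour described in the "Cloning transforms" section can be illustrated with a minimal, library-free sketch. ``ToyTransform`` and ``ToyEnv`` are hypothetical stand-ins, not TorchRL classes. The sketch only assumes what the text states: an environment registers itself as the parent of its transform, a transform that already has a parent cannot be reused, and ``clone()`` drops the parent while keeping the identity of the buffers.

```python
import copy


class ToyTransform:
    """Hypothetical stand-in for a transform (not TorchRL's Transform)."""

    def __init__(self, name):
        self.name = name
        self.buffer = [0.0]  # stands in for a registered buffer
        self.parent = None

    def clone(self):
        # Shallow copy: the clone shares the very same buffer object,
        # but its parent is erased so it can be attached elsewhere.
        new = copy.copy(self)
        new.parent = None
        return new


class ToyEnv:
    """Hypothetical environment that registers itself as the parent."""

    def __init__(self, transform):
        if transform.parent is not None:
            raise RuntimeError("transform already has a parent; clone it first")
        transform.parent = self
        self.transform = transform


first = ToyTransform("cat_frames")
env_a = ToyEnv(first)

# Reusing the registered transform directly is refused.
try:
    ToyEnv(first)
    reuse_failed = False
except RuntimeError:
    reuse_failed = True

# A clone has no parent, so it can be registered to another environment.
cloned = first.clone()
env_b = ToyEnv(cloned)

# An in-place buffer update is visible through both transforms.
first.buffer[0] = 1.0
```

In this toy version the shared list plays the role of a buffer living in the same process. Across processes the two copies would no longer be the same object unless the storage is shared, which is the caveat the section raises for distributed settings.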