@@ -37,6 +37,13 @@ It contains tutorials and the API reference.
3737TorchRL relies on [`TensorDict`](https://github.com/pytorch-labs/tensordict/),
3838a convenient data structure<sup>(1)</sup> to pass data from
3939one object to another without friction.
40+
41+
42+ Here is an example of how the [environment API](https://pytorch.org/rl/reference/envs.html)
43+ relies on tensordict to carry data from one function to another during a rollout
44+ execution:
45+ ![Rollout execution](docs/source/_static/img/rollout.gif)
46+
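As a minimal, self-contained sketch (the entries and shapes below are made up for illustration), a `TensorDict` behaves like a batched dictionary of tensors: all entries share the leading batch dimensions, so a whole batch of data can be indexed, reshaped or moved across devices in a single call.

```python
import torch
from tensordict import TensorDict

# a batch of 4 transitions stored in a single, dictionary-like structure
data = TensorDict(
    {
        "observation": torch.randn(4, 3),
        "action": torch.randn(4, 1),
        "reward": torch.zeros(4, 1),
    },
    batch_size=[4],
)
half = data[:2]            # indexing applies to every entry at once
data_cpu = data.to("cpu")  # and so does device placement
```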
4047`TensorDict` makes it easy to re-use pieces of code across environments, models and
4148algorithms. For instance, here's how to code a rollout in TorchRL:
4249 <details>
@@ -156,202 +163,207 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
156163
157164## Features
158165
159- - a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
160- executes the aforementioned training loop. Through a hooking mechanism,
161- it also supports any logging or data transformation operation at any given
162- time.
163-
164166- A common [interface for environments](torchrl/envs)
165- which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and stateless execution (e.g. model-based environments).
166- The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
167- A common PyTorch-first [tensor specification class](torchrl/data/tensor_specs.py) is also provided.
168- <details>
169- <summary>Code</summary>
170-
171- ```python
172- env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
173- env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
174- tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
175- assert tensordict.shape == torch.Size([4, 20])  # 4 envs, 20-step rollout
176- env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
177- ```
178- </details>
167+ which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and stateless execution
168+ (e.g. model-based environments).
169+ The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
170+ A common PyTorch-first [tensor specification class](torchrl/data/tensor_specs.py) is also provided.
171+ TorchRL's environment API is simple but stringent and specific. Check the
172+ [documentation](https://pytorch.org/rl/reference/envs.html)
173+ and [tutorial](https://pytorch.org/rl/tutorials/pendulum.html) to learn more!
174+ <details>
175+ <summary>Code</summary>
176+
177+ ```python
178+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
179+ env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
180+ tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
181+ assert tensordict.shape == torch.Size([4, 20])  # 4 envs, 20-step rollout
182+ env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
183+ ```
184+ </details>
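To make the read/write contract of the environment API concrete, here is a hedged, single-step sketch; the calls below (`reset`, `action_spec.rand`, `step`) exist in TorchRL, but the exact layout of the output keys may vary between versions, so treat it as illustrative:

```python
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
tensordict = env.reset()                       # a tensordict holding the initial observation
tensordict["action"] = env.action_spec.rand()  # a random, spec-compliant action
tensordict = env.step(tensordict)              # step results are written back into the tensordict
print(tensordict)                              # observation, reward and done flag travel together
```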
179185
180186- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
181- Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
182- learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
183- <details>
184- <summary>Code</summary>
185-
186- ```python
187- env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
188- collector = MultiaSyncDataCollector(
189-     [env_make, env_make],
190-     policy=policy,
191-     devices=["cuda:0", "cuda:0"],
192-     total_frames=10000,
193-     frames_per_batch=50,
194-     ...
195- )
196- for i, tensordict_data in enumerate(collector):
197-     loss = loss_module(tensordict_data)
198-     loss.backward()
199-     optim.step()
200-     optim.zero_grad()
201-     collector.update_policy_weights_()
202- ```
203- </details>
187+ Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
188+ learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
189+ <details>
190+ <summary>Code</summary>
191+
192+ ```python
193+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
194+ collector = MultiaSyncDataCollector(
195+     [env_make, env_make],
196+     policy=policy,
197+     devices=["cuda:0", "cuda:0"],
198+     total_frames=10000,
199+     frames_per_batch=50,
200+     ...
201+ )
202+ for i, tensordict_data in enumerate(collector):
203+     loss = loss_module(tensordict_data)
204+     loss.backward()
205+     optim.step()
206+     optim.zero_grad()
207+     collector.update_policy_weights_()
208+ ```
209+ </details>
204210
205211- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
206- <details>
207- <summary>Code</summary>
208-
209- ```python
210- storage = LazyMemmapStorage(  # memory-mapped (physical) storage
211-     cfg.buffer_size,
212-     scratch_dir="/tmp/"
213- )
214- buffer = TensorDictPrioritizedReplayBuffer(
215-     alpha=0.7,
216-     beta=0.5,
217-     collate_fn=lambda x: x,
218-     pin_memory=device != torch.device("cpu"),
219-     prefetch=10,  # multi-threaded sampling
220-     storage=storage
221- )
222- ```
223- </details>
212+ <details>
213+ <summary>Code</summary>
214+
215+ ```python
216+ storage = LazyMemmapStorage(  # memory-mapped (physical) storage
217+     cfg.buffer_size,
218+     scratch_dir="/tmp/"
219+ )
220+ buffer = TensorDictPrioritizedReplayBuffer(
221+     alpha=0.7,
222+     beta=0.5,
223+     collate_fn=lambda x: x,
224+     pin_memory=device != torch.device("cpu"),
225+     prefetch=10,  # multi-threaded sampling
226+     storage=storage
227+ )
228+ ```
229+ </details>
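On the usage side, data typically enters and leaves the buffer as tensordicts. The snippet below is a sketch with hypothetical variable names (`tensordict_data`, `loss_module`), and the exact `sample` signature may differ across versions:

```python
buffer.extend(tensordict_data)  # add a batch of collected transitions (a TensorDict)
sample = buffer.sample(256)     # sample a training batch, returned as a TensorDict
loss = loss_module(sample)      # feed it directly to a TensorDict-aware loss module
```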
224230
225231- cross-library [environment transforms](torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
226- executed on device and in a vectorized fashion<sup>(2)</sup>,
227- which process and prepare the data coming out of the environments to be used by the agent:
228- <details>
229- <summary>Code</summary>
230-
231- ```python
232- env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
233- env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
234- env = TransformedEnv(
235-     env_base,
236-     Compose(
237-         ToTensorImage(),
238-         ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
239- )
240- tensordict = env.reset()
241- assert tensordict.device == torch.device("cuda:0")
242- ```
243- Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of
244- successive frames (`CatFrames`), resizing (`Resize`) and many more.
245-
246- Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
247- easy to add and remove them at will:
248- ```python
249- env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at index 0
250- ```
251- Nevertheless, transforms can access and execute operations on the parent environment:
252- ```python
253- transform = env.transform[1]  # gathers the second transform in the list
254- parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
255- ```
256- </details>
232+ executed on device and in a vectorized fashion<sup>(2)</sup>,
233+ which process and prepare the data coming out of the environments to be used by the agent:
234+ <details>
235+ <summary>Code</summary>
236+
237+ ```python
238+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
239+ env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
240+ env = TransformedEnv(
241+     env_base,
242+     Compose(
243+         ToTensorImage(),
244+         ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
245+ )
246+ tensordict = env.reset()
247+ assert tensordict.device == torch.device("cuda:0")
248+ ```
249+ Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of
250+ successive frames (`CatFrames`), resizing (`Resize`) and many more.
251+
252+ Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
253+ easy to add and remove them at will:
254+ ```python
255+ env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at index 0
256+ ```
257+ Nevertheless, transforms can access and execute operations on the parent environment:
258+ ```python
259+ transform = env.transform[1]  # gathers the second transform in the list
260+ parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
261+ ```
262+ </details>
257263
258264- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch-labs/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
259265- various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
260- <details>
261- <summary>Code</summary>
262-
263- ```python
264- # create an nn.Module
265- common_module = ConvNet(
266-     bias_last_layer=True,
267-     depth=None,
268-     num_cells=[32, 64, 64],
269-     kernel_sizes=[8, 4, 3],
270-     strides=[4, 2, 1],
271- )
272- # Wrap it in a SafeModule, indicating what key to read in and where to
273- # write out the output
274- common_module = SafeModule(
275-     common_module,
276-     in_keys=["pixels"],
277-     out_keys=["hidden"],
278- )
279- # Wrap the policy module in NormalParamsWrapper, such that the output
280- # tensor is split into loc and scale, and scale is mapped onto a positive space
281- policy_module = SafeModule(
282-     NormalParamsWrapper(
283-         MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
284-     ),
285-     in_keys=["hidden"],
286-     out_keys=["loc", "scale"],
287- )
288- # Use a SafeProbabilisticSequential to combine the SafeModule with a
289- # SafeProbabilisticModule, indicating how to build the
290- # torch.distributions.Distribution object and what to do with it
291- policy_module = SafeProbabilisticSequential(  # stochastic policy
292-     policy_module,
293-     SafeProbabilisticModule(
294-         in_keys=["loc", "scale"],
295-         out_keys="action",
296-         distribution_class=TanhNormal,
297-     ),
298- )
299- value_module = MLP(
300-     num_cells=[64, 64],
301-     out_features=1,
302-     activation=nn.ELU,
303- )
304- # Wrap the policy and value function in a common module
305- actor_value = ActorValueOperator(common_module, policy_module, value_module)
306- # extract a standalone policy from this
307- standalone_policy = actor_value.get_policy_operator()
308- ```
309- </details>
266+ <details>
267+ <summary>Code</summary>
268+
269+ ```python
270+ # create an nn.Module
271+ common_module = ConvNet(
272+     bias_last_layer=True,
273+     depth=None,
274+     num_cells=[32, 64, 64],
275+     kernel_sizes=[8, 4, 3],
276+     strides=[4, 2, 1],
277+ )
278+ # Wrap it in a SafeModule, indicating what key to read in and where to
279+ # write out the output
280+ common_module = SafeModule(
281+     common_module,
282+     in_keys=["pixels"],
283+     out_keys=["hidden"],
284+ )
285+ # Wrap the policy module in NormalParamsWrapper, such that the output
286+ # tensor is split into loc and scale, and scale is mapped onto a positive space
287+ policy_module = SafeModule(
288+     NormalParamsWrapper(
289+         MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
290+     ),
291+     in_keys=["hidden"],
292+     out_keys=["loc", "scale"],
293+ )
294+ # Use a SafeProbabilisticSequential to combine the SafeModule with a
295+ # SafeProbabilisticModule, indicating how to build the
296+ # torch.distributions.Distribution object and what to do with it
297+ policy_module = SafeProbabilisticSequential(  # stochastic policy
298+     policy_module,
299+     SafeProbabilisticModule(
300+         in_keys=["loc", "scale"],
301+         out_keys="action",
302+         distribution_class=TanhNormal,
303+     ),
304+ )
305+ value_module = MLP(
306+     num_cells=[64, 64],
307+     out_features=1,
308+     activation=nn.ELU,
309+ )
310+ # Wrap the policy and value function in a common module
311+ actor_value = ActorValueOperator(common_module, policy_module, value_module)
312+ # extract a standalone policy from this
313+ standalone_policy = actor_value.get_policy_operator()
314+ ```
315+ </details>
310316
311317- exploration [wrappers](torchrl/modules/tensordict_module/exploration.py) and
312- [modules](torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
313- <details>
314- <summary>Code</summary>
315-
316- ```python
317- policy_explore = EGreedyWrapper(policy)
318- with set_exploration_mode("random"):
319-     tensordict = policy_explore(tensordict)  # will use eps-greedy
320- with set_exploration_mode("mode"):
321-     tensordict = policy_explore(tensordict)  # will not use eps-greedy
322- ```
323- </details>
318+ [modules](torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
319+ <details>
320+ <summary>Code</summary>
321+
322+ ```python
323+ policy_explore = EGreedyWrapper(policy)
324+ with set_exploration_mode("random"):
325+     tensordict = policy_explore(tensordict)  # will use eps-greedy
326+ with set_exploration_mode("mode"):
327+     tensordict = policy_explore(tensordict)  # will not use eps-greedy
328+ ```
329+ </details>
324330
325331- A series of efficient [loss modules](https://github.com/pytorch/rl/blob/main/torchrl/objectives/costs)
326- and highly vectorized
327- [functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/returns/functional.py)
328- computation.
332+ and highly vectorized
333+ [functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/returns/functional.py)
334+ computation.
335+
336+ <details>
337+ <summary>Code</summary>
329338
330- <details>
331- <summary>Code</summary>
339+ ### Loss modules
340+ ```python
341+ from torchrl.objectives import DQNLoss
342+ loss_module = DQNLoss(value_network=value_network, gamma=0.99)
343+ tensordict = replay_buffer.sample(batch_size)
344+ loss = loss_module(tensordict)
345+ ```
332346
333- ### Loss modules
334- ```python
335- from torchrl.objectives import DQNLoss
336- loss_module = DQNLoss(value_network=value_network, gamma=0.99)
337- tensordict = replay_buffer.sample(batch_size)
338- loss = loss_module(tensordict)
339- ```
347+ ### Advantage computation
348+ ```python
349+ from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
350+ advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
351+ ```
340352
341- ### Advantage computation
342- ```python
343- from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
344- advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
345- ```
353+ </details>
346354
347- </details>
355+ - a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
356+ executes the aforementioned training loop. Through a hooking mechanism,
357+ it also supports any logging or data transformation operation at any given
358+ time.
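As a hedged sketch of what such a training loop could look like (the constructor arguments and defaults below are assumptions and may differ from the released `Trainer` API):

```python
from torchrl.trainers import Trainer

trainer = Trainer(
    collector=collector,        # e.g. a multiprocess data collector
    total_frames=1_000_000,     # stopping condition
    loss_module=loss_module,    # e.g. DQNLoss
    optimizer=optim,
    optim_steps_per_batch=8,    # gradient steps per collected batch
)
# logging, target-network updates, replay-buffer extension, etc. can be
# attached as hooks at predefined points of the loop before calling train()
trainer.train()
```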
348359
349360- various [recipes](torchrl/trainers/helpers/models.py) to build models that
350361 correspond to the environment being deployed.
351362
352363If you feel a feature is missing from the library, please submit an issue!
353364If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](CONTRIBUTING.md) page.
354365
366+
355367## Examples, tutorials and demos
356368
357369A series of [examples](examples/) is provided for illustrative purposes: