
Commit 75ca4b4

[Feature] PPOTrainer (#3117)
1 parent 7b74a9b commit 75ca4b4


69 files changed (+7249 -169 lines)

.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -27,3 +27,4 @@ dependencies:
 - ray
 - av
 - h5py
+- numpy<2.0.0

README.md

Lines changed: 94 additions & 4 deletions
@@ -25,9 +25,33 @@
 
 ## 🚀 What's New
 
+### 🚀 **Command-Line Training Interface** - Train RL Agents Without Writing Code! (Experimental)
+
+TorchRL now provides a **powerful command-line interface** that lets you train state-of-the-art RL agents with simple bash commands! No Python scripting required - just run training with customizable parameters:
+
+- 🎯 **One-Command Training**: `python sota-implementations/ppo_trainer/train.py`
+- ⚙️ **Full Customization**: Override any parameter via the command line: `trainer.total_frames=2000000 optimizer.lr=0.0003`
+- 🌍 **Multi-Environment Support**: Switch between Gym, Brax, DM Control, and more with `env=gym training_env.create_env_fn.base_env.env_name=HalfCheetah-v4`
+- 📊 **Built-in Logging**: TensorBoard, Weights & Biases, and CSV logging out of the box
+- 🔧 **Hydra-Powered**: Leverages Hydra's powerful configuration system for maximum flexibility
+- 🏃‍♂️ **Production Ready**: Same robust training pipeline as our SOTA implementations
+
+**Perfect for**: Researchers, practitioners, and anyone who wants to train RL agents without diving into implementation details.
+
+⚠️ **Note**: This is an experimental feature. The API may change in future versions. We welcome feedback and contributions to help improve this implementation!
+
+📋 **Prerequisites**: The training interface requires Hydra for configuration management. Install it with:
+```bash
+pip install "torchrl[utils]"
+# or manually:
+pip install hydra-core omegaconf
+```
+
+Check out the [complete CLI documentation](https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer) to get started!
+
 ### LLM API - Complete Framework for Language Model Fine-tuning
 
-TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
+TorchRL also includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
 
 - 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines - more to come!
 - 💬 **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
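The command-line overrides in the hunk above are resolved by Hydra before the trainer is instantiated. Below is a minimal sketch of the same flow driven through Hydra's compose API instead of the CLI; the config path, config name, and field names are assumptions inferred from this diff, not a documented TorchRL entry point.

```python
# Sketch only: programmatic equivalent of the CLI overrides above.
# The config_path/config_name are assumptions based on the train.py
# shown later in this commit.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="sota-implementations/ppo_trainer/config", version_base="1.1"):
    cfg = compose(
        config_name="config",
        overrides=[
            "trainer.total_frames=2000000",  # same key=value syntax as the CLI
            "optimizer.lr=0.0003",
        ],
    )
    trainer = instantiate(cfg.trainer)  # builds the trainer from the composed config
    trainer.train()
```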
@@ -74,6 +98,67 @@ for data in collector:
 
 </details>
 
+### 🧪 PPOTrainer (Experimental) - High-Level Training Interface
+
+TorchRL now includes an **experimental PPOTrainer** that provides a complete, configurable PPO training solution! This prototype feature combines TorchRL's modular components into a cohesive training system with sensible defaults:
+
+- 🎯 **Complete Training Pipeline**: Handles environment setup, data collection, loss computation, and optimization automatically
+- ⚙️ **Extensive Configuration**: Comprehensive Hydra-based config system for easy experimentation and hyperparameter tuning
+- 📊 **Built-in Logging**: Automatic tracking of rewards, actions, episode completion rates, and training statistics
+- 🔧 **Modular Design**: Built on existing TorchRL components (collectors, losses, replay buffers) for maximum flexibility
+- 📝 **Minimal Code**: Complete SOTA implementation in [just ~20 lines](sota-implementations/ppo_trainer/train.py)!
+
+**Working Example**: See [`sota-implementations/ppo_trainer/`](sota-implementations/ppo_trainer/) for a complete, working PPO implementation that trains on Pendulum-v1 with full Hydra configuration support.
+
+**Prerequisites**: Requires Hydra for configuration management: `pip install "torchrl[utils]"`
+
+<details>
+<summary>Complete Training Script (sota-implementations/ppo_trainer/train.py)</summary>
+
+```python
+import hydra
+from torchrl.trainers.algorithms.configs import *
+
+@hydra.main(config_path="config", config_name="config", version_base="1.1")
+def main(cfg):
+    trainer = hydra.utils.instantiate(cfg.trainer)
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
+```
+*Complete PPO training in ~20 lines with full configurability.*
+
+</details>
+
+<details>
+<summary>API Usage Examples</summary>
+
+```bash
+# Basic usage - train PPO on Pendulum-v1 with default settings
+python sota-implementations/ppo_trainer/train.py
+
+# Custom configuration with command-line overrides
+python sota-implementations/ppo_trainer/train.py \
+    trainer.total_frames=2000000 \
+    training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
+    networks.policy_network.num_cells=[256,256] \
+    optimizer.lr=0.0003
+
+# Use a different environment and logger
+python sota-implementations/ppo_trainer/train.py \
+    env=gym \
+    training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
+    logger=tensorboard
+
+# See all available options
+python sota-implementations/ppo_trainer/train.py --help
+```
+
+</details>
+
+**Future Plans**: Additional algorithm trainers (SAC, TD3, DQN) and full integration of all TorchRL components within the configuration system are planned for upcoming releases.
+
 ## Key features
 
 - 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
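As a companion to the PPOTrainer hunk above, here is a rough, illustrative sketch of the standard TorchRL building blocks (collector, GAE advantage module, clipped PPO loss) that the bullet list says the trainer composes. This is not the trainer's actual internals; the network sizes, keys, and hyperparameters are arbitrary assumptions.

```python
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

# Policy: MLP -> (loc, scale) -> TanhNormal distribution over actions.
actor_net = torch.nn.Sequential(
    MLP(in_features=n_obs, out_features=2 * n_act, num_cells=[64, 64]),
    NormalParamExtractor(),
)
policy = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
# Critic: maps observations to a scalar state value.
value = ValueOperator(
    MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
)

collector = SyncDataCollector(env, policy, frames_per_batch=1000, total_frames=100_000)
advantage = GAE(gamma=0.99, lmbda=0.95, value_network=value)
loss_fn = ClipPPOLoss(actor_network=policy, critic_network=value)
optim = torch.optim.Adam(loss_fn.parameters(), lr=3e-4)

for batch in collector:
    advantage(batch)  # writes "advantage" and "value_target" into the batch
    losses = loss_fn(batch)
    loss = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
    loss.backward()
    optim.step()
    optim.zero_grad()
```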
@@ -932,7 +1017,7 @@ source torchrl/bin/activate # On Windows use: venv\Scripts\activate
 Or create a conda environment where the packages will be installed.
 
 ```
-conda create --name torchrl python=3.9
+conda create --name torchrl python=3.10
 conda activate torchrl
 ```
 
@@ -945,7 +1030,12 @@ install the latest (nightly) PyTorch release or the latest stable version of PyT
 See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
 including `pip3` or other special installation instructions.
 
-TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"` etc.
+TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"`, `"torchrl[utils]"` etc.
+
+For the experimental training interface and configuration system, install:
+```bash
+pip3 install "torchrl[utils]"  # Includes hydra-core and other utilities
+```
 
 #### Torchrl
 
@@ -989,7 +1079,7 @@ Importantly, the nightly builds require the nightly builds of PyTorch too.
 Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them.
 
 
-**Disclaimer**: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not
+**Disclaimer**: As of today, TorchRL requires Python 3.10+ and is roughly compatible with any pytorch version >= 2.1. Installing it will not
 directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest
 PyTorch to be installed and we are working hard to loosen that requirement.
 The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.

benchmarks/test_collectors_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -9,10 +9,10 @@
 import torch.cuda
 import tqdm
 
-from torchrl.collectors import SyncDataCollector
-from torchrl.collectors.collectors import (
+from torchrl.collectors import (
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
+    SyncDataCollector,
 )
 from torchrl.data import LazyTensorStorage, ReplayBuffer
 from torchrl.data.utils import CloudpickleWrapper
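This file's change is purely an import consolidation; for context, here is a minimal, illustrative sketch of how `SyncDataCollector` is typically driven. The environment choice and frame counts are arbitrary.

```python
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# With policy=None the collector falls back to a random policy built
# from the env's action spec.
collector = SyncDataCollector(
    env,
    policy=None,
    frames_per_batch=200,  # each iteration yields a TensorDict of 200 frames
    total_frames=1_000,    # stop after 1_000 frames in total
)
for batch in collector:
    print(batch["observation"].shape)  # torch.Size([200, 3]) for Pendulum-v1
collector.shutdown()
```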

docs/source/reference/collectors.rst

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from torchrl.collectors import SyncDataCollector.. currentmodule:: torchrl.collectors
+.. currentmodule:: torchrl.collectors
 
 torchrl.collectors package
 ==========================
