Skip to content

Commit b0c6ddf

Browse files
realtmxiKunlun-Zhu
authored andcommitted
add roll out control
1 parent ae20111 commit b0c6ddf

File tree

6 files changed

+1441
-9
lines changed

6 files changed

+1441
-9
lines changed

README.md

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Setting up LLM agent environment for online RL tunning.
7979
Connect to specialized reasoning models such as deepseek-r1, QwQ-32B for more complex inference tasks to collect comprehensive agent trajectories.
8080

8181
3. RL-Tuning Model Paradigm
82-
Provide an RL fine-tuning approach for customizing the agents behavior in our agent environment.
82+
Provide an RL fine-tuning approach for customizing the agent's behavior in our agent environment.
8383

8484
4. Test on Agent Benchmarks
8585
Evaluate our framework on agentic benchmark such as Webshop, GAIA, OSWorld, AgentBench
@@ -210,20 +210,21 @@ We are still laboriously developing this part, welcome feedback.
210210
First, create a conda environment and activate it:
211211

212212
```bash
213+
# Create a new conda environment
213214
conda create -n openmanus-rl python=3.11 -y
214215
conda activate openmanus-rl
215216
```
216217

217218
Then, install the required dependencies:
218219

219-
220220
```bash
221-
# install torch [or you can skip this step and let vllm to install the correct version for you]
221+
# Install PyTorch with CUDA support
222222
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
223-
# install vllm
224-
pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
225223

226-
# verl
224+
# Install vllm for efficient inference
225+
pip3 install vllm==0.6.3
226+
227+
# Install the main package
227228
pip install -e .
228229

229230
# flash attention 2
@@ -237,6 +238,38 @@ cd agentenv
237238
pip install -e .
238239
```
239240

241+
## Environment Setup
242+
243+
### WebShop Environment Setup as an example, more environment could be found on the agentgym
244+
245+
To set up the WebShop environment for evaluation:
246+
247+
```bash
248+
# Change to the agentenv-webshop directory
249+
cd agentenv-webshop
250+
251+
# Create a new conda environment for WebShop
252+
conda env create -n webshop -f environment.yml
253+
conda activate webshop
254+
255+
# Setup the environment
256+
bash ./setup.sh
257+
```
258+
259+
### Launching the WebShop Server
260+
261+
After setting up the environment, you can launch the WebShop server:
262+
263+
```bash
264+
# Make sure the webshop conda environment is activated
265+
conda activate webshop
266+
267+
# Launch the server (default port: 36001)
268+
webshop --port 36001
269+
```
270+
271+
Note: The WebShop environment requires specific versions of Python, PyTorch, Faiss, and Java. The setup script will handle these dependencies automatically.
272+
240273
## Quick start
241274

242275
Train a reasoning + search LLM on NQ dataset with e5 as the retriever and wikipedia as the corpus.
@@ -246,9 +279,6 @@ Train a reasoning + search LLM on NQ dataset with e5 as the retriever and wikipe
246279
From https://huggingface.co/datasets/CharlieDreemur/OpenManus-RL
247280

248281
(3) Launch a local AgentGym server.
249-
```bash
250-
todo here
251-
```
252282

253283
(4) Run RL training (PPO) with Llama-3.2-3b-base.
254284
```bash
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# AgentGym Rollout Controller Design Document
2+
3+
## Overview
4+
5+
This document outlines the design and implementation of the Rollout Controller for the AgentGym framework. The Rollout Controller extends AgentGym's capabilities by adding support for advanced exploration strategies (Tree of Thoughts, Monte Carlo Tree Search, etc.) and trajectory storage, while maintaining compatibility with the existing architecture.
6+
7+
## Motivation
8+
9+
The standard AgentGym implementation uses a straightforward ReAct approach for agent interaction with environments. While this works well for simple scenarios, more complex reasoning and decision-making often benefit from advanced exploration strategies that consider multiple possible action paths. Additionally, storing and analyzing trajectories is crucial for reinforcement learning and model improvement.
10+
11+
## Architecture
12+
13+
The Rollout Controller architecture consists of three main components:
14+
15+
1. **Rollout Strategies**: Implementations of different exploration algorithms
16+
2. **Trajectory Storage**: Systems for persisting and retrieving trajectories
17+
3. **Rollout Controller**: Main controller that integrates strategies and storage with AgentGym
18+
19+
### Integration with AgentGym
20+
21+
The implementation extends the existing AgentGym components rather than replacing them:
22+
23+
- `RolloutController` extends `BaseAgentEnvController` from AgentGym
24+
- All strategies accept and return `ExperienceOutput` objects for compatibility
25+
- The controller uses `BaseTask` and `BaseEnvClient` from AgentGym for environment interaction
26+
```
27+
BaseAgentEnvController
28+
29+
|
30+
RolloutController ←→ IRolloutStrategy
31+
| ↑
32+
| |
33+
| BaseRolloutStrategy
34+
| ↑
35+
| |
36+
| ┌─────┴─────────┐
37+
| | |
38+
| StandardReAct ToT/MCTS/etc.
39+
|
40+
41+
ITrajectoryStorage
42+
43+
┌───────┴───────┐
44+
| |
45+
MongoDBStorage FileStorage
46+
```
47+
## Components
48+
49+
### Rollout Strategies
50+
51+
All strategies implement the `IRolloutStrategy` interface, ensuring a consistent API:
52+
53+
```python
54+
class IRolloutStrategy(ABC):
55+
@abstractmethod
56+
def execute(
57+
self,
58+
model: PreTrainedModel,
59+
tokenizer: PreTrainedTokenizerBase,
60+
client: BaseEnvClient,
61+
initial_observation: str,
62+
generation_config: Optional[GenerationConfig] = None,
63+
max_rounds: Optional[int] = None
64+
) -> List[ExperienceOutput]:
65+
"""Execute the strategy and return trajectories"""
66+
pass
67+
```
68+
69+
#### Implemented Strategies
70+
71+
1. **StandardReActStrategy**: The default strategy used in AgentGym, which follows a linear path of observation → action → observation.
72+
73+
2. **ToTStrategy (Tree of Thoughts)**: Implements a tree exploration approach where:
74+
- The agent considers multiple possible actions at each step
75+
- For each action, it explores the resulting states recursively
76+
- This creates a tree of potential trajectories
77+
- Parameters control the breadth (number of branches) and depth of exploration
78+
79+
3. **MCTSStrategy (Monte Carlo Tree Search)**: Implements the MCTS algorithm for more efficient exploration of large action spaces:
80+
- Selection: Choose promising nodes to explore
81+
- Expansion: Add new child nodes
82+
- Simulation: Run rollouts to estimate node value
83+
- Backpropagation: Update node values based on simulation results
84+
85+
### Trajectory Storage
86+
87+
The `ITrajectoryStorage` interface defines methods for saving and retrieving trajectories:
88+
89+
```python
90+
class ITrajectoryStorage:
91+
def save_trajectory(self, env_name, task_id, strategy_name, trajectory, metadata=None) -> str:
92+
pass
93+
94+
def save_trajectories(self, env_name, task_ids, strategy_name, trajectories, metadata=None) -> List[str]:
95+
pass
96+
97+
def get_trajectory(self, trajectory_id) -> Optional[Dict]:
98+
pass
99+
100+
def get_trajectories(self, env_name=None, task_id=None, strategy_name=None, limit=100) -> List[Dict]:
101+
pass
102+
103+
def get_best_trajectory(self, env_name, task_id) -> Optional[Dict]:
104+
pass
105+
```
106+
107+
#### Implementations
108+
109+
1. **MongoDBTrajectoryStorage**: Stores trajectories in MongoDB for scalable, queryable access.
110+
2. **FileTrajectoryStorage**: A simpler implementation that stores trajectories in JSONL files.
111+
112+
### Rollout Controller
113+
114+
The `RolloutController` class orchestrates the rollout process:
115+
116+
```python
117+
class RolloutController(BaseAgentEnvController):
118+
def __init__(
119+
self,
120+
agent: Agent,
121+
tasks: List[BaseTask],
122+
strategy: Optional[IRolloutStrategy] = None,
123+
storage: Optional[ITrajectoryStorage] = None,
124+
max_workers: int = 10
125+
):
126+
# initialization...
127+
128+
def rollout(
129+
self,
130+
generation_config: Optional[GenerationConfig] = None,
131+
max_rounds: Optional[int] = None,
132+
idxs: Optional[List[int]] = None,
133+
save_to_storage: bool = True,
134+
parallel: bool = True,
135+
batch_size: int = 1,
136+
metadata: Optional[Dict[str, Any]] = None
137+
) -> List[ExperienceOutput]:
138+
# implementation...
139+
```
140+
141+
Key features:
142+
- **Configurable strategy**: Use different exploration strategies for different tasks
143+
- **Parallel execution**: Process multiple environments concurrently
144+
- **Trajectory storage**: Automatically save trajectories for later analysis
145+
- **Batch processing**: Process environments in batches for memory efficiency
146+
147+
## Usage Examples
148+
149+
### Basic Usage with Tree of Thoughts
150+
151+
```python
152+
from agentenv.controller import Agent
153+
from agentenv.envs import WebshopTask
154+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
155+
156+
from rollout_controller import RolloutController
157+
from strategies import ToTStrategy
158+
from database import MongoDBTrajectoryStorage
159+
160+
# Load model and tokenizer
161+
model = AutoModelForCausalLM.from_pretrained("model_path")
162+
tokenizer = AutoTokenizer.from_pretrained("model_path")
163+
agent = Agent(model, tokenizer)
164+
165+
# Create task
166+
task = WebshopTask(
167+
client_args={"env_server_base": "http://localhost:36001", "data_len": 200},
168+
n_clients=1
169+
)
170+
171+
# Create storage
172+
storage = MongoDBTrajectoryStorage()
173+
174+
# Create strategy
175+
strategy = ToTStrategy(num_branches=3, depth=2)
176+
177+
# Create controller
178+
controller = RolloutController(
179+
agent=agent,
180+
tasks=[task],
181+
strategy=strategy,
182+
storage=storage
183+
)
184+
185+
# Run rollout
186+
results = controller.rollout(
187+
generation_config=GenerationConfig(max_length=4096),
188+
max_rounds=7,
189+
idxs=[0, 1, 2], # Run on first three tasks
190+
parallel=True
191+
)
192+
193+
# Analyze results
194+
for result in results:
195+
print(f"Reward: {result.reward}")
196+
```
197+
198+
### Switching Strategies
199+
200+
```python
201+
from strategies import MCTSStrategy
202+
203+
# Switch to MCTS strategy
204+
mcts_strategy = MCTSStrategy(num_simulations=50, exploration_weight=1.0)
205+
controller.set_strategy(mcts_strategy)
206+
207+
# Run rollout with new strategy
208+
results = controller.rollout(idxs=[0, 1, 2])
209+
```
210+
211+
## Implementation Considerations
212+
213+
### Concurrency and Thread Safety
214+
215+
- The controller uses ThreadPoolExecutor for parallel rollouts
216+
- Each rollout uses a separate environment client instance
217+
- Careful consideration of thread safety in strategy implementations
218+
219+
### Memory Management
220+
221+
- Batch processing to avoid excessive memory usage
222+
- Proper cleanup of resources after rollout
223+
- Copy-on-write for environment branching
224+
225+
### Error Handling
226+
227+
- Robust error handling at multiple levels
228+
- Failed rollouts don't interrupt the entire process
229+
- Detailed error reporting
230+
231+
## TODO

0 commit comments

Comments
 (0)