|
| 1 | +# AgentGym Rollout Controller Design Document |
| 2 | + |
| 3 | +## Overview |
| 4 | + |
| 5 | +This document outlines the design and implementation of the Rollout Controller for the AgentGym framework. The Rollout Controller extends AgentGym's capabilities by adding support for advanced exploration strategies (Tree of Thoughts, Monte Carlo Tree Search, etc.) and trajectory storage, while maintaining compatibility with the existing architecture. |
| 6 | + |
| 7 | +## Motivation |
| 8 | + |
| 9 | +The standard AgentGym implementation uses a straightforward ReAct approach for agent interaction with environments. While this works well for simple scenarios, more complex reasoning and decision-making often benefit from advanced exploration strategies that consider multiple possible action paths. Additionally, storing and analyzing trajectories is crucial for reinforcement learning and model improvement. |
| 10 | + |
| 11 | +## Architecture |
| 12 | + |
| 13 | +The Rollout Controller architecture consists of three main components: |
| 14 | + |
| 15 | +1. **Rollout Strategies**: Implementations of different exploration algorithms |
| 16 | +2. **Trajectory Storage**: Systems for persisting and retrieving trajectories |
| 17 | +3. **Rollout Controller**: Main controller that integrates strategies and storage with AgentGym |
| 18 | + |
| 19 | +### Integration with AgentGym |
| 20 | + |
| 21 | +The implementation extends the existing AgentGym components rather than replacing them: |
| 22 | + |
| 23 | +- `RolloutController` extends `BaseAgentEnvController` from AgentGym |
| 24 | +- All strategies accept and return `ExperienceOutput` objects for compatibility |
| 25 | +- The controller uses `BaseTask` and `BaseEnvClient` from AgentGym for environment interaction |
| 26 | +``` |
| 27 | + BaseAgentEnvController |
| 28 | + ↑ |
| 29 | + | |
| 30 | + RolloutController ←→ IRolloutStrategy |
| 31 | + | ↑ |
| 32 | + | | |
| 33 | + | BaseRolloutStrategy |
| 34 | + | ↑ |
| 35 | + | | |
| 36 | + | ┌─────┴─────────┐ |
| 37 | + | | | |
| 38 | + | StandardReAct ToT/MCTS/etc. |
| 39 | + | |
| 40 | + ↓ |
| 41 | + ITrajectoryStorage |
| 42 | + ↑ |
| 43 | + ┌───────┴───────┐ |
| 44 | + | | |
| 45 | + MongoDBStorage FileStorage |
| 46 | +``` |
| 47 | +## Components |
| 48 | + |
| 49 | +### Rollout Strategies |
| 50 | + |
| 51 | +All strategies implement the `IRolloutStrategy` interface, ensuring a consistent API: |
| 52 | + |
| 53 | +```python |
| 54 | +class IRolloutStrategy(ABC): |
| 55 | + @abstractmethod |
| 56 | + def execute( |
| 57 | + self, |
| 58 | + model: PreTrainedModel, |
| 59 | + tokenizer: PreTrainedTokenizerBase, |
| 60 | + client: BaseEnvClient, |
| 61 | + initial_observation: str, |
| 62 | + generation_config: Optional[GenerationConfig] = None, |
| 63 | + max_rounds: Optional[int] = None |
| 64 | + ) -> List[ExperienceOutput]: |
| 65 | + """Execute the strategy and return trajectories""" |
| 66 | + pass |
| 67 | +``` |
| 68 | + |
| 69 | +#### Implemented Strategies |
| 70 | + |
| 71 | +1. **StandardReActStrategy**: The default strategy used in AgentGym, which follows a linear path of observation → action → observation. |
| 72 | + |
| 73 | +2. **ToTStrategy (Tree of Thoughts)**: Implements a tree exploration approach where: |
| 74 | + - The agent considers multiple possible actions at each step |
| 75 | + - For each action, it explores the resulting states recursively |
| 76 | + - This creates a tree of potential trajectories |
| 77 | + - Parameters control the breadth (number of branches) and depth of exploration |
| 78 | + |
| 79 | +3. **MCTSStrategy (Monte Carlo Tree Search)**: Implements the MCTS algorithm for more efficient exploration of large action spaces: |
| 80 | + - Selection: Choose promising nodes to explore |
| 81 | + - Expansion: Add new child nodes |
| 82 | + - Simulation: Run rollouts to estimate node value |
| 83 | + - Backpropagation: Update node values based on simulation results |
| 84 | + |
| 85 | +### Trajectory Storage |
| 86 | + |
| 87 | +The `ITrajectoryStorage` interface defines methods for saving and retrieving trajectories: |
| 88 | + |
| 89 | +```python |
| 90 | +class ITrajectoryStorage: |
| 91 | + def save_trajectory(self, env_name, task_id, strategy_name, trajectory, metadata=None) -> str: |
| 92 | + pass |
| 93 | + |
| 94 | + def save_trajectories(self, env_name, task_ids, strategy_name, trajectories, metadata=None) -> List[str]: |
| 95 | + pass |
| 96 | + |
| 97 | + def get_trajectory(self, trajectory_id) -> Optional[Dict]: |
| 98 | + pass |
| 99 | + |
| 100 | + def get_trajectories(self, env_name=None, task_id=None, strategy_name=None, limit=100) -> List[Dict]: |
| 101 | + pass |
| 102 | + |
| 103 | + def get_best_trajectory(self, env_name, task_id) -> Optional[Dict]: |
| 104 | + pass |
| 105 | +``` |
| 106 | + |
| 107 | +#### Implementations |
| 108 | + |
| 109 | +1. **MongoDBTrajectoryStorage**: Stores trajectories in MongoDB for scalable, queryable access. |
| 110 | +2. **FileTrajectoryStorage**: A simpler implementation that stores trajectories in JSONL files. |
| 111 | + |
| 112 | +### Rollout Controller |
| 113 | + |
| 114 | +The `RolloutController` class orchestrates the rollout process: |
| 115 | + |
| 116 | +```python |
| 117 | +class RolloutController(BaseAgentEnvController): |
| 118 | + def __init__( |
| 119 | + self, |
| 120 | + agent: Agent, |
| 121 | + tasks: List[BaseTask], |
| 122 | + strategy: Optional[IRolloutStrategy] = None, |
| 123 | + storage: Optional[ITrajectoryStorage] = None, |
| 124 | + max_workers: int = 10 |
| 125 | + ): |
| 126 | + # initialization... |
| 127 | + |
| 128 | + def rollout( |
| 129 | + self, |
| 130 | + generation_config: Optional[GenerationConfig] = None, |
| 131 | + max_rounds: Optional[int] = None, |
| 132 | + idxs: Optional[List[int]] = None, |
| 133 | + save_to_storage: bool = True, |
| 134 | + parallel: bool = True, |
| 135 | + batch_size: int = 1, |
| 136 | + metadata: Optional[Dict[str, Any]] = None |
| 137 | + ) -> List[ExperienceOutput]: |
| 138 | + # implementation... |
| 139 | +``` |
| 140 | + |
| 141 | +Key features: |
| 142 | +- **Configurable strategy**: Use different exploration strategies for different tasks |
| 143 | +- **Parallel execution**: Process multiple environments concurrently |
| 144 | +- **Trajectory storage**: Automatically save trajectories for later analysis |
| 145 | +- **Batch processing**: Process environments in batches for memory efficiency |
| 146 | + |
| 147 | +## Usage Examples |
| 148 | + |
| 149 | +### Basic Usage with Tree of Thoughts |
| 150 | + |
| 151 | +```python |
| 152 | +from agentenv.controller import Agent |
| 153 | +from agentenv.envs import WebshopTask |
| 154 | +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig |
| 155 | + |
| 156 | +from rollout_controller import RolloutController |
| 157 | +from strategies import ToTStrategy |
| 158 | +from database import MongoDBTrajectoryStorage |
| 159 | + |
| 160 | +# Load model and tokenizer |
| 161 | +model = AutoModelForCausalLM.from_pretrained("model_path") |
| 162 | +tokenizer = AutoTokenizer.from_pretrained("model_path") |
| 163 | +agent = Agent(model, tokenizer) |
| 164 | + |
| 165 | +# Create task |
| 166 | +task = WebshopTask( |
| 167 | + client_args={"env_server_base": "http://localhost:36001", "data_len": 200}, |
| 168 | + n_clients=1 |
| 169 | +) |
| 170 | + |
| 171 | +# Create storage |
| 172 | +storage = MongoDBTrajectoryStorage() |
| 173 | + |
| 174 | +# Create strategy |
| 175 | +strategy = ToTStrategy(num_branches=3, depth=2) |
| 176 | + |
| 177 | +# Create controller |
| 178 | +controller = RolloutController( |
| 179 | + agent=agent, |
| 180 | + tasks=[task], |
| 181 | + strategy=strategy, |
| 182 | + storage=storage |
| 183 | +) |
| 184 | + |
| 185 | +# Run rollout |
| 186 | +results = controller.rollout( |
| 187 | + generation_config=GenerationConfig(max_length=4096), |
| 188 | + max_rounds=7, |
| 189 | + idxs=[0, 1, 2], # Run on first three tasks |
| 190 | + parallel=True |
| 191 | +) |
| 192 | + |
| 193 | +# Analyze results |
| 194 | +for result in results: |
| 195 | + print(f"Reward: {result.reward}") |
| 196 | +``` |
| 197 | + |
| 198 | +### Switching Strategies |
| 199 | + |
| 200 | +```python |
| 201 | +from strategies import MCTSStrategy |
| 202 | + |
| 203 | +# Switch to MCTS strategy |
| 204 | +mcts_strategy = MCTSStrategy(num_simulations=50, exploration_weight=1.0) |
| 205 | +controller.set_strategy(mcts_strategy) |
| 206 | + |
| 207 | +# Run rollout with new strategy |
| 208 | +results = controller.rollout(idxs=[0, 1, 2]) |
| 209 | +``` |
| 210 | + |
| 211 | +## Implementation Considerations |
| 212 | + |
| 213 | +### Concurrency and Thread Safety |
| 214 | + |
| 215 | +- The controller uses ThreadPoolExecutor for parallel rollouts |
| 216 | +- Each rollout uses a separate environment client instance |
| 217 | +- Careful consideration of thread safety in strategy implementations |
| 218 | + |
| 219 | +### Memory Management |
| 220 | + |
| 221 | +- Batch processing to avoid excessive memory usage |
| 222 | +- Proper cleanup of resources after rollout |
| 223 | +- Copy-on-write for environment branching |
| 224 | + |
| 225 | +### Error Handling |
| 226 | + |
| 227 | +- Robust error handling at multiple levels |
| 228 | +- Failed rollouts don't interrupt the entire process |
| 229 | +- Detailed error reporting |
| 230 | + |
| 231 | +## TODO |
0 commit comments