# Decoupled Asynchronous Example

This example demonstrates an **asynchronous decoupled architecture** for JAX-vLLM inference offloading, designed for high-throughput RL post-training workflows like GRPO.

## Key Differences from Synchronous Version

| Aspect | Synchronous | Asynchronous |
|--------|-------------|--------------|
| **Driver** | JAX Controller | Prompt Dispatcher (autonomous) |
| **Prompt dispatch** | Waits for `sync/weights_ready` signal | Continuous; blocks only on queue backpressure |
| **Weight updates** | Blocking, every iteration | Controlled via `MAX_STALENESS` |
| **Result delivery** | All B×R results in one message | Streamed per rollout |
| **JAX processing** | Waits for full batch | Processes each prompt group as it completes |
| **Staleness control** | None (always fresh) | Configurable via `--max-staleness` |

## Use Case

This architecture is designed for **RL post-training** workflows (e.g., GRPO) where:
- High throughput is critical
- Controllable staleness is acceptable (vLLM runs ahead by at most N prompts)
- Results should be processed as they become available
- Weight updates are less frequent than in synchronous mode

## Architecture

```
┌─────────────────────────────────────────────────────────────────────┐
│                               Gateway                               │
│                        (gRPC Message Broker)                        │
│                                                                     │
│  - Routes messages between processes via pub/sub topics             │
│  - Provides KV store for cross-process coordination                 │
└─────────────────────────────────────────────────────────────────────┘
                                   │
         ┌─────────────────────────┼─────────────────────────┐
         │                         │                         │
         ▼                         ▼                         ▼
┌─────────────────┐       ┌─────────────────┐       ┌─────────────────┐
│   vLLM Worker   │       │ JAX Controller  │       │Prompt Dispatcher│
│   (vLLM GPUs)   │       │   (JAX GPUs)    │       │   (CPU only)    │
│                 │       │                 │       │                 │
│ - Bounded       │◄──────│ - Accumulates   │       │ - Autonomous    │
│   staleness     │ NCCL  │   rollout       │       │   dispatch loop │
│   control       │       │   results       │       │ - Blocks on     │
│ - Streams       │       │ - Pushes weight │       │   backpressure  │
│   results       │──────►│   updates after │       │ - No sync       │
│ - Partial batch │       │   N prompts     │       │   signals       │
│   processing    │       │                 │       │                 │
└─────────────────┘       └─────────────────┘       └─────────────────┘
```

| 51 | + |
| 52 | +## The Four Processes |
| 53 | + |
| 54 | +### 1. Gateway (`gateway.py`) |
| 55 | +The central message broker (shared with synchronous version). |
| 56 | + |
| 57 | +- **Location**: `jax_inference_offloading/controller/gateway.py` |
| 58 | +- **Role**: Routes gRPC messages, provides KV store |
| 59 | +- **GPU**: None required |
| 60 | +- **New in async**: Supports bounded queues for backpressure |
| 61 | + |
### 2. vLLM Worker (`vllm_worker.py`)
Runs inference with streaming results and bounded staleness control.

- **Location**: `examples/decoupled_asynchronous/vllm_worker.py`
- **Role**: Generates rollouts, streams results, receives weight updates
- **GPU**: Requires vLLM-assigned GPUs
- **Library components used**:
  - `RolloutServicer` - Shared with the sync version; handles inference and weight updates
  - `AsyncRolloutClient` - Adds staleness control, request queuing, and partial batch processing
- **Key async behavior**:
  - Receives weight updates via the `WEIGHT_UPDATES` topic (same mechanism as sync)
  - Streams individual `RolloutResult` messages as rollouts complete
  - **Bounded staleness**: Limits itself to `MAX_STALENESS` prompts before waiting for a weight update
  - **Partial batch processing**: Can split batches to hit exact staleness limits
  - Queues inference requests when at the staleness limit; drains the queue after a weight update

### 3. JAX Controller (`jax_controller.py`)
Accumulates streamed results and pushes weight updates.

- **Location**: `examples/decoupled_asynchronous/jax_controller.py`
- **Role**: Accumulates rollouts by prompt, triggers weight updates
- **GPU**: Requires JAX-assigned GPUs
- **Library components used**:
  - `RolloutAccumulator` - Thread-safe accumulator for grouping rollout results
  - `result_consumer_thread` - Background thread for consuming the gRPC stream
- **Key async behavior**:
  - The **consumer thread** continuously reads from the gRPC stream (never blocks)
  - The **main thread** polls for completed groups and handles weight updates
  - When the main thread blocks on NCCL, the consumer thread keeps reading (no message loss)
  - Groups results by `(batch_id, prompt_index)`
  - After N prompts complete (`UPDATE_INTERVAL`), pushes a weight update

### 4. Prompt Dispatcher (`prompt_dispatcher.py`)
Autonomously dispatches prompts without waiting for signals.

- **Location**: `examples/decoupled_asynchronous/prompt_dispatcher.py`
- **Role**: Continuously sends inference requests
- **GPU**: None required
- **Key async behavior**:
  - No sync-signal subscription
  - Blocks only when the inference queue is full (backpressure)
  - Generates a unique `batch_id` for result correlation

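The dispatch loop is essentially a bounded-queue producer. A minimal sketch in Python (the `send_queue` stands in for the gateway's inference-request channel; all names here are illustrative, not the library's actual API):

```python
import itertools
import queue
import uuid


def dispatch_loop(send_queue, prompts, batch_size=3, num_batches=10, num_rollouts=4):
    """Continuously dispatch prompt batches, blocking only on backpressure."""
    batches = range(num_batches) if num_batches else itertools.count()
    for _ in batches:
        request = {
            "batch_id": uuid.uuid4().hex,  # unique ID for result correlation
            "prompts": prompts[:batch_size],
            "num_rollouts": num_rollouts,
        }
        # put() blocks when the bounded queue is full -- this backpressure
        # is the only thing that throttles the dispatcher; there is no
        # wait on a sync signal.
        send_queue.put(request)
```

A bounded `queue.Queue` is used here because its blocking `put()` gives the backpressure behavior for free; the real gateway channel plays the same role.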
## Data Flow

### Per-Rollout Streaming

```
vLLM Worker               Gateway               JAX Controller
    │                        │                        │
    │  generate rollouts     │                        │
    │                        │                        │
    │  RolloutResult(P0,R0)  │                        │
    │───────────────────────►│───────────────────────►│
    │                        │                  [accumulate]
    │  RolloutResult(P1,R0)  │                        │
    │───────────────────────►│───────────────────────►│
    │                        │                  [accumulate]
    │  RolloutResult(P0,R1)  │                        │
    │───────────────────────►│───────────────────────►│
    │                        │                 [P0 complete!]
    │  ...                   │                        │
```

| 125 | + |
| 126 | +### Weight Updates (Same as Sync, Less Frequent) |
| 127 | + |
| 128 | +``` |
| 129 | +JAX Controller Gateway vLLM Worker |
| 130 | + │ │ │ |
| 131 | + │ [enough prompts done] │ │ |
| 132 | + │ │ │ |
| 133 | + │ StartWeightUpdate │ │ |
| 134 | + │ (via WEIGHT_UPDATES) │ │ |
| 135 | + │────────────────────────►│─────────────────────►│ |
| 136 | + │ │ [start NCCL recv] |
| 137 | + │ │ │ |
| 138 | + │◄════════════ NCCL Transfer ════════════════════►│ |
| 139 | + │ │ │ |
| 140 | + │ [continue accumulating │ [continue with │ |
| 141 | + │ rollout results] │ new weights] │ |
| 142 | +``` |
| 143 | + |
| 144 | +## Result Correlation for GRPO |
| 145 | + |
| 146 | +Each streamed `RolloutResult` includes: |
| 147 | +- `batch_id`: Unique ID for the inference request |
| 148 | +- `prompt_index`: Which prompt within the batch (0-based) |
| 149 | +- `rollout_index`: Which rollout for this prompt (0 to num_outputs-1) |

The JAX controller uses a **thread-safe accumulator** with a dedicated consumer thread:

```python
# Consumer thread (runs independently, never blocks the main thread)
def result_consumer_thread(results_stream, accumulator):
    for delivery in results_stream:
        result = unpack(delivery)
        accumulator.add_result(result)  # Thread-safe

# Main thread (can block on NCCL without losing results)
while not done:
    groups = accumulator.get_completed_groups(timeout=0.1)  # Waits at most 0.1 s
    for batch_id, prompt_idx, group in groups:
        # Ready for GRPO: compute rewards, advantages, gradients
        process_group(group)

    if should_update_weights:
        transfer_engine.update_weights(params)  # Blocks, but consumer keeps reading
```

This design ensures that **no results are lost** even when the main thread blocks on NCCL weight transfers.
## Bounded Staleness Control

By default, vLLM runs as fast as possible, which can lead to unbounded staleness (vLLM may process many prompts with stale weights before JAX sends an update). The `--max-staleness` option provides controllable staleness:

```bash
./decoupled_async.sh \
  --update-interval=5 \
  --max-staleness=5 \
  ...
```

| 183 | + |
| 184 | +### How It Works |
| 185 | + |
| 186 | +1. vLLM tracks how many prompts it has processed since the last weight update |
| 187 | +2. When the count reaches `MAX_STALENESS`, vLLM queues subsequent inference requests |
| 188 | +3. When a weight update arrives, vLLM processes it and drains the queue |
| 189 | +4. This ensures vLLM never runs more than `MAX_STALENESS` prompts ahead |
| 190 | + |
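Steps 1-4 boil down to a small piece of bookkeeping, sketched below (illustrative names, not the worker's actual implementation). Note that `admit` may grant only part of a request, which is the partial batch processing described further down:

```python
class StalenessGate:
    """Tracks prompts processed since the last weight update (a sketch)."""

    def __init__(self, max_staleness):
        self.max_staleness = max_staleness  # 0 means unlimited
        self.since_update = 0

    def admit(self, num_prompts):
        """Return how many of `num_prompts` fit in the staleness budget.

        The caller queues whatever is not admitted until the next update.
        """
        if self.max_staleness == 0:
            return num_prompts
        allowed = min(num_prompts, self.max_staleness - self.since_update)
        self.since_update += allowed
        return allowed

    def on_weight_update(self):
        """A weight update resets the counter, letting queued work drain."""
        self.since_update = 0
```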
```
Without staleness control (MAX_STALENESS=0):
vLLM: [INF_1][INF_2]...[INF_10][WU_1][WU_2]...
      └ All use stale weights ┘

With staleness control (MAX_STALENESS=5):
vLLM: [INF_1]...[INF_5][WU_1][INF_6]...[INF_10][WU_2]...
      └── 5 prompts ──┘      └── 5 prompts ───┘
```

### Configuration Constraint

**Simple rule**: `MAX_STALENESS >= UPDATE_INTERVAL` (or `MAX_STALENESS=0` for unlimited staleness)

vLLM supports **partial batch processing**: if a batch of 3 prompts arrives but only 2 fit within the staleness limit, vLLM will:
1. Process 2 prompts immediately
2. Queue the remaining 1 prompt until the next weight update

This means vLLM can process exactly up to `MAX_STALENESS` prompts between weight updates, regardless of batch size.

**Examples**:
```bash
# Valid: MAX_STALENESS >= UPDATE_INTERVAL
--max-staleness=5 --update-interval=5   # ✓ vLLM processes exactly 5
--max-staleness=5 --update-interval=3   # ✓ vLLM processes 5, JAX updates at 3
--max-staleness=1 --update-interval=1   # ✓ vLLM processes 1 prompt at a time

# Invalid: would deadlock
--max-staleness=3 --update-interval=5   # ✗ vLLM blocks at 3, JAX needs 5
```

The vLLM worker validates this at startup and will error with a helpful message if misconfigured.

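That startup validation amounts to a check like the following (a sketch; the actual function name and error message in the worker may differ):

```python
def validate_staleness_config(max_staleness, update_interval):
    """Reject configurations where vLLM would stall before JAX can update."""
    if max_staleness != 0 and max_staleness < update_interval:
        raise ValueError(
            f"max_staleness ({max_staleness}) must be >= update_interval "
            f"({update_interval}), or 0 for unlimited: vLLM would block "
            f"after {max_staleness} prompts while JAX still waits for "
            f"{update_interval} before pushing weights."
        )
```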
## Running the Example

```bash
cd examples/decoupled_asynchronous

./decoupled_async.sh \
  --model-path=/path/to/model \
  --param-mapping-path=../mappings/llama3_1b_param_mapping.json \
  --num-batches=10 \
  --batch-size=3 \
  --num-rollouts=4 \
  --update-interval=10
```

### Required Arguments
- `--model-path` - Path to HuggingFace model checkpoint
- `--param-mapping-path` - Path to JSON parameter mapping file

### Async-Specific Arguments
- `--num-batches=N` - Number of batches to dispatch (default: 10; 0 for infinite)
- `--batch-size=N` - Prompts per batch (default: 3)
- `--num-rollouts=N` - Rollouts per prompt (default: 4)
- `--update-interval=N` - Push a weight update after N prompts (default: 10)
- `--max-staleness=N` - Max prompts vLLM processes before requiring a weight update (default: 0 = unlimited)
- `--max-completed-prompts=N` - Stop after N prompts (default: 100)
- `--dispatch-delay=FLOAT` - Delay between dispatches in seconds (default: 0.0)

### Common Arguments
- `--n-gpus-vllm=N` - GPUs for vLLM (default: 4)
- `--n-gpus-jax=N` - GPUs for JAX (default: 4)
- `--transfer-mode=MODE` - Weight transfer mode: `grouped`/`fused`/`unfused`
- `--gateway-port=PORT` - gRPC port (default: 50051)
- `--debug` - Enable verbose logging

## Customization for GRPO

To add GRPO training logic, modify `jax_controller.py`:

```python
# In the main loop, after getting completed groups from the accumulator
for batch_id, prompt_idx, group in accumulator.get_completed_groups():
    # group is a list of RolloutResult messages for this prompt

    # 1. Compute rewards for each rollout
    rewards = compute_rewards(group)

    # 2. Compute GRPO advantages (relative within the group)
    advantages = grpo_advantages(rewards)

    # 3. Accumulate gradients
    grads = compute_grpo_gradients(group, advantages, params)
    accumulated_grads = accumulate(accumulated_grads, grads)

    prompts_processed += 1

    # 4. After enough prompts, apply gradients and push weights
    if prompts_processed >= update_interval:
        params = apply_gradients(params, accumulated_grads)
        transfer_engine.update_weights(params)  # NCCL transfer to vLLM
        prompts_processed = 0
        accumulated_grads = None
```

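`compute_rewards`, `grpo_advantages`, `compute_grpo_gradients`, `accumulate`, and `apply_gradients` above are placeholders for your own training logic. For the advantage step specifically, GRPO typically uses group-normalized rewards, which can be sketched as:

```python
import statistics


def grpo_advantages(rewards, eps=1e-6):
    """Advantage of each rollout relative to its prompt group:
    (reward - group mean) / (group std + eps).
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    return [(r - mean) / (std + eps) for r in rewards]
```

Because the baseline is the group mean, no learned value function is needed; the `eps` term guards against a zero std when all rollouts in a group score identically.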
## Trade-offs

### Advantages
- **Higher throughput**: vLLM doesn't wait for JAX; dispatch doesn't wait for inference
- **Better GPU utilization**: JAX and vLLM can work in parallel
- **Incremental processing**: Results are processed as they arrive
- **No message loss**: The consumer thread reads results even when the main thread blocks on NCCL
- **Controllable staleness**: Use `--max-staleness` to limit how far vLLM runs ahead

### Considerations
- **Slightly stale weights**: vLLM may generate some rollouts with weights from the previous update
- **More complex state management**: Pending groups must be tracked and results correlated
- **Staleness/throughput trade-off**: Lower `MAX_STALENESS` means fresher weights but lower throughput

## Library Components

This example uses the following library components from `jax_inference_offloading`:

### For vLLM Worker
```python
from jax_inference_offloading.controller.rollout_client import (
    RolloutServicer,            # Handles inference, weight updates, handshake
    AsyncRolloutClient,         # Adds staleness control and request queuing
    make_async_rollout_client,  # Factory function
)
```

### For JAX Controller
```python
from jax_inference_offloading.engines import (
    RolloutAccumulator,      # Thread-safe result accumulator
    result_consumer_thread,  # Background gRPC consumer
)
```

These components can be reused in custom implementations without copying the example code.

## Environment Variables

Set `VERBOSE_CONSUMER=1` to see each result as the consumer thread receives it (helpful for debugging).