Commit e92c41f

keshavb96 and Steboss authored
Single controller (#1973)

This PR demonstrates how to use the building blocks in the bridge codebase to implement orchestration through a single controller. In totality we now have:

1. Decoupling of the processes that do weight updates / training, prompt dispatching, and rollout generation (vLLM). See the example in examples/decoupled_synchronous.
2. Async RL support; see the example in examples/decoupled_asynchronous.
3. A single-controller (sync and async) example demonstrating how to orchestrate the RL training loop and inference offloading with a single controller, using the building blocks of the bridge.

Co-authored-by: Steboss <sbosisio@nvidia.com>

1 parent 3abea18 commit e92c41f

67 files changed: +7608 −3457 lines
# Decoupled Asynchronous Example

This example demonstrates an **asynchronous decoupled architecture** for JAX-vLLM inference offloading, designed for high-throughput RL post-training workflows such as GRPO.

## Key Differences from Synchronous Version

| Aspect | Synchronous | Asynchronous |
|--------|-------------|--------------|
| **Driver** | JAX Controller | Prompt Dispatcher (autonomous) |
| **Prompt dispatch** | Waits for `sync/weights_ready` signal | Continuous; blocks only on queue backpressure |
| **Weight updates** | Blocking, every iteration | Controlled via `MAX_STALENESS` |
| **Result delivery** | All B×R results in one message | Streamed per rollout |
| **JAX processing** | Waits for full batch | Processes each prompt group as it completes |
| **Staleness control** | None (always fresh) | Configurable via `--max-staleness` |

## Use Case

This architecture is designed for **RL post-training** workflows (e.g., GRPO) where:

- High throughput is critical
- Controllable staleness is acceptable (vLLM runs ahead by at most N prompts)
- Results should be processed as they become available
- Weight updates are less frequent than in synchronous mode

## Architecture

```
┌─────────────────────────────────────────────────────────────────────────┐
│                                 Gateway                                 │
│                          (gRPC Message Broker)                          │
│                                                                         │
│  - Routes messages between processes via pub/sub topics                 │
│  - Provides KV store for cross-process coordination                     │
└─────────────────────────────────────────────────────────────────────────┘

         ┌─────────────────────────┼─────────────────────────┐
         │                         │                         │
         ▼                         ▼                         ▼
┌─────────────────┐       ┌─────────────────┐       ┌─────────────────┐
│   vLLM Worker   │       │ JAX Controller  │       │Prompt Dispatcher│
│   (vLLM GPUs)   │       │   (JAX GPUs)    │       │   (CPU only)    │
│                 │       │                 │       │                 │
│ - Bounded       │◄──────│ - Accumulates   │       │ - Autonomous    │
│   staleness     │ NCCL  │   rollout       │       │   dispatch loop │
│   control       │       │   results       │       │ - Blocks on     │
│ - Streams       │       │ - Pushes weight │       │   backpressure  │
│   results       │──────►│   updates after │       │ - No sync       │
│ - Partial batch │       │   N prompts     │       │   signals       │
│   processing    │       │                 │       │                 │
└─────────────────┘       └─────────────────┘       └─────────────────┘
```

## The Four Processes

### 1. Gateway (`gateway.py`)

The central message broker (shared with the synchronous version).

- **Location**: `jax_inference_offloading/controller/gateway.py`
- **Role**: Routes gRPC messages, provides the KV store
- **GPU**: None required
- **New in async**: Supports bounded queues for backpressure

### 2. vLLM Worker (`vllm_worker.py`)

Runs inference with streaming results and bounded staleness control.

- **Location**: `examples/decoupled_asynchronous/vllm_worker.py`
- **Role**: Generates rollouts, streams results, receives weight updates
- **GPU**: Requires vLLM-assigned GPUs
- **Library components used**:
  - `RolloutServicer` - Shared with the sync version; handles inference and weight updates
  - `AsyncRolloutClient` - Adds staleness control, request queuing, and partial batch processing
- **Key async behavior**:
  - Receives weight updates via the `WEIGHT_UPDATES` topic (same mechanism as sync)
  - Streams individual `RolloutResult` messages as rollouts complete
  - **Bounded staleness**: Limits itself to `MAX_STALENESS` prompts before waiting for a weight update
  - **Partial batch processing**: Can split batches to hit exact staleness limits
  - Queues inference requests when at the staleness limit; drains the queue after a weight update

### 3. JAX Controller (`jax_controller.py`)

Accumulates streamed results and pushes weight updates.

- **Location**: `examples/decoupled_asynchronous/jax_controller.py`
- **Role**: Accumulates rollouts by prompt, triggers weight updates
- **GPU**: Requires JAX-assigned GPUs
- **Library components used**:
  - `RolloutAccumulator` - Thread-safe accumulator for grouping rollout results
  - `result_consumer_thread` - Background thread for consuming the gRPC stream
- **Key async behavior**:
  - The **consumer thread** continuously reads from the gRPC stream (never blocks)
  - The **main thread** polls for completed groups and handles weight updates
  - When the main thread blocks on NCCL, the consumer thread keeps reading (no message loss)
  - Groups results by `(batch_id, prompt_index)`
  - After N prompts complete (`UPDATE_INTERVAL`), pushes a weight update

### 4. Prompt Dispatcher (`prompt_dispatcher.py`)

Autonomously dispatches prompts without waiting for signals.

- **Location**: `examples/decoupled_asynchronous/prompt_dispatcher.py`
- **Role**: Continuously sends inference requests
- **GPU**: None required
- **Key async behavior**:
  - No sync-signal subscription
  - Blocks only when the inference queue is full (backpressure)
  - Generates a unique `batch_id` for result correlation

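The dispatcher's autonomous loop can be sketched as follows. This is a minimal illustration, not the example's actual code: Python's `queue.Queue` stands in for the gateway's bounded inference queue, and all names (`dispatch_loop`, `request_queue`) are hypothetical.

```python
import queue
import threading
import uuid

def dispatch_loop(request_queue, num_batches, batch_size):
    """Dispatch batches as fast as the queue allows; put() blocks when the
    queue is full, which is the only form of backpressure."""
    for _ in range(num_batches):
        request_queue.put({
            "batch_id": uuid.uuid4().hex,  # unique ID for result correlation
            "prompts": [f"prompt-{i}" for i in range(batch_size)],
        })

# Demo: the dispatcher runs ahead of a slow consumer through a bounded queue.
requests = queue.Queue(maxsize=2)  # small bound -> dispatcher blocks when full
received = []

consumer = threading.Thread(
    target=lambda: [received.append(requests.get()) for _ in range(5)])
consumer.start()
dispatch_loop(requests, num_batches=5, batch_size=3)
consumer.join()
print(len(received))  # 5
```

Because the only blocking point is `put()` on a full queue, the dispatcher needs no knowledge of weight-update timing at all.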
## Data Flow

### Per-Rollout Streaming

```
vLLM Worker                 Gateway                JAX Controller
     │                         │                        │
     │  generate rollouts      │                        │
     │                         │                        │
     │  RolloutResult(P0,R0)   │                        │
     │────────────────────────►│───────────────────────►│
     │                         │                  [accumulate]
     │  RolloutResult(P1,R0)   │                        │
     │────────────────────────►│───────────────────────►│
     │                         │                  [accumulate]
     │  RolloutResult(P0,R1)   │                        │
     │────────────────────────►│───────────────────────►│
     │                         │                 [P0 complete!]
     │          ...            │                        │
```

### Weight Updates (Same as Sync, Less Frequent)

```
JAX Controller              Gateway                 vLLM Worker
     │                         │                        │
     │ [enough prompts done]   │                        │
     │                         │                        │
     │  StartWeightUpdate      │                        │
     │  (via WEIGHT_UPDATES)   │                        │
     │────────────────────────►│───────────────────────►│
     │                         │               [start NCCL recv]
     │                         │                        │
     │◄════════════════ NCCL Transfer ════════════════►│
     │                         │                        │
     │ [continue accumulating  │     [continue with     │
     │  rollout results]       │      new weights]      │
```

## Result Correlation for GRPO

Each streamed `RolloutResult` includes:

- `batch_id`: Unique ID for the inference request
- `prompt_index`: Which prompt within the batch (0-based)
- `rollout_index`: Which rollout for this prompt (0 to num_outputs-1)

The JAX controller uses a **thread-safe accumulator** with a dedicated consumer thread:

```python
# Consumer thread (never blocks the main loop; runs independently)
def result_consumer_thread(results_stream, accumulator):
    for delivery in results_stream:
        result = unpack(delivery)
        accumulator.add_result(result)  # Thread-safe

# Main thread (can block on NCCL without losing results)
while not done:
    groups = accumulator.get_completed_groups(timeout=0.1)  # Non-blocking
    for batch_id, prompt_idx, group in groups:
        # Ready for GRPO: compute rewards, advantages, gradients
        process_group(group)

    if should_update_weights:
        transfer_engine.update_weights(params)  # Blocks, but the consumer keeps reading
```

This design ensures that **no results are lost** even when the main thread blocks on NCCL weight transfers.

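To make the accumulator's contract concrete, here is a minimal stand-in, illustrative only and not the library's actual `RolloutAccumulator` implementation. It groups results by `(batch_id, prompt_index)` and releases a group once a fixed number of rollouts has arrived:

```python
import threading
from collections import defaultdict

class MiniAccumulator:
    """Illustrative thread-safe accumulator (hypothetical, simplified)."""

    def __init__(self, group_size):
        self.group_size = group_size
        self._lock = threading.Lock()       # guards both containers below
        self._pending = defaultdict(list)   # (batch_id, prompt_index) -> results
        self._completed = []                # fully assembled groups
        self._ready = threading.Event()

    def add_result(self, result):
        """Called from the consumer thread; safe to call concurrently."""
        key = (result["batch_id"], result["prompt_index"])
        with self._lock:
            self._pending[key].append(result)
            if len(self._pending[key]) == self.group_size:
                self._completed.append((key, self._pending.pop(key)))
                self._ready.set()

    def get_completed_groups(self, timeout=0.1):
        """Called from the main thread; returns [] if nothing completes in time."""
        self._ready.wait(timeout)
        with self._lock:
            groups, self._completed = self._completed, []
            self._ready.clear()
            return groups

acc = MiniAccumulator(group_size=2)
acc.add_result({"batch_id": "b0", "prompt_index": 0, "rollout_index": 0})
acc.add_result({"batch_id": "b0", "prompt_index": 0, "rollout_index": 1})
groups = acc.get_completed_groups()
print(groups[0][0], len(groups[0][1]))  # ('b0', 0) 2
```

The lock makes `add_result` safe to call from the consumer thread while the main thread polls; the event lets the main thread wait with a timeout instead of busy-spinning.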
## Bounded Staleness Control

By default, vLLM runs as fast as possible, which can lead to unbounded staleness (vLLM may process many prompts with stale weights before JAX sends an update). The `--max-staleness` option makes staleness controllable:

```bash
./decoupled_async.sh \
    --update-interval=5 \
    --max-staleness=5 \
    ...
```

### How It Works

1. vLLM tracks how many prompts it has processed since the last weight update
2. When the count reaches `MAX_STALENESS`, vLLM queues subsequent inference requests
3. When a weight update arrives, vLLM applies it and drains the queue
4. This ensures vLLM never runs more than `MAX_STALENESS` prompts ahead

```
Without staleness control (MAX_STALENESS=0):
  vLLM: [INF_1][INF_2]...[INF_10][WU_1][WU_2]...
        └────────── All use stale weights ──────┘

With staleness control (MAX_STALENESS=5):
  vLLM: [INF_1][INF_2][WU_1][INF_3][INF_4][WU_2]...
        └─ 5 prompts ─┘     └─ 5 prompts ─┘
```

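The four steps above can be sketched in a few lines. This is a hypothetical illustration of the gating logic, not the worker's actual code; the class and attribute names are invented for the sketch:

```python
import queue

class StalenessGate:
    """Illustrative bounded-staleness gate: count prompts since the last
    weight update and queue anything beyond max_staleness (0 = unlimited)."""

    def __init__(self, max_staleness):
        self.max_staleness = max_staleness
        self.since_update = 0
        self.backlog = queue.Queue()
        self.processed = []

    def submit(self, prompts):
        for p in prompts:
            if self.max_staleness and self.since_update >= self.max_staleness:
                self.backlog.put(p)        # at the limit: queue, don't run
            else:
                self.processed.append(p)   # under the limit: run immediately
                self.since_update += 1

    def on_weight_update(self):
        self.since_update = 0              # fresh weights: reset the counter
        drained = []
        while not self.backlog.empty():    # drain queued requests
            drained.append(self.backlog.get())
        self.submit(drained)

gate = StalenessGate(max_staleness=2)
gate.submit(["p0", "p1", "p2"])  # partial batch: only 2 fit the limit
print(gate.processed)            # ['p0', 'p1']
gate.on_weight_update()
print(gate.processed)            # ['p0', 'p1', 'p2']
```

Note that the demo also exercises partial batch processing: a batch of 3 prompts is split so that exactly `max_staleness` prompts run before the update.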
### Configuration Constraint

**Simple rule**: `MAX_STALENESS >= UPDATE_INTERVAL`

vLLM supports **partial batch processing**: if a batch of 3 prompts arrives but only 2 fit within the staleness limit, vLLM will:

1. Process 2 prompts immediately
2. Queue the remaining prompt until the next weight update

This means vLLM can process exactly up to `MAX_STALENESS` prompts, regardless of batch size.

**Examples**:

```bash
# Valid: MAX_STALENESS >= UPDATE_INTERVAL
--max-staleness=5 --update-interval=5   # ✓ vLLM processes exactly 5
--max-staleness=5 --update-interval=3   # ✓ vLLM processes 5, JAX updates at 3
--max-staleness=1 --update-interval=1   # ✓ vLLM processes 1 prompt at a time

# Invalid: would deadlock
--max-staleness=3 --update-interval=5   # ✗ vLLM blocks at 3, JAX needs 5
```

The vLLM worker validates this at startup and errors with a helpful message if misconfigured.

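A startup check for this rule might look like the following. This is a hypothetical sketch; the worker's real validation function and error message may differ:

```python
def validate_staleness_config(max_staleness, update_interval):
    """Reject configs where vLLM would stall before JAX can ever push weights.
    max_staleness=0 means unlimited, so it is always valid."""
    if 0 < max_staleness < update_interval:
        raise ValueError(
            f"max_staleness={max_staleness} < update_interval={update_interval}: "
            f"vLLM would block after {max_staleness} prompts while JAX waits "
            f"for {update_interval} before pushing weights (deadlock). "
            "Set max_staleness >= update_interval, or 0 for unlimited."
        )

validate_staleness_config(5, 3)  # OK: vLLM runs ahead, JAX updates sooner
validate_staleness_config(0, 5)  # OK: unlimited staleness
```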
## Running the Example

```bash
cd examples/decoupled_asynchronous

./decoupled_async.sh \
    --model-path=/path/to/model \
    --param-mapping-path=../mappings/llama3_1b_param_mapping.json \
    --num-batches=10 \
    --batch-size=3 \
    --num-rollouts=4 \
    --update-interval=10
```

### Required Arguments

- `--model-path` - Path to the HuggingFace model checkpoint
- `--param-mapping-path` - Path to the JSON parameter mapping file

### Async-Specific Arguments

- `--num-batches=N` - Number of batches to dispatch (default: 10; 0 for infinite)
- `--batch-size=N` - Prompts per batch (default: 3)
- `--num-rollouts=N` - Rollouts per prompt (default: 4)
- `--update-interval=N` - Push a weight update after N prompts (default: 10)
- `--max-staleness=N` - Max prompts vLLM processes before requiring a weight update (default: 0 = unlimited)
- `--max-completed-prompts=N` - Stop after N prompts (default: 100)
- `--dispatch-delay=FLOAT` - Delay between dispatches in seconds (default: 0.0)

### Common Arguments

- `--n-gpus-vllm=N` - GPUs for vLLM (default: 4)
- `--n-gpus-jax=N` - GPUs for JAX (default: 4)
- `--transfer-mode=MODE` - Weight transfer mode: `grouped`/`fused`/`unfused`
- `--gateway-port=PORT` - gRPC port (default: 50051)
- `--debug` - Enable verbose logging

## Customization for GRPO

To add GRPO training logic, modify `jax_controller.py`:

```python
# In the main loop, after getting completed groups from the accumulator
for batch_id, prompt_idx, group in accumulator.get_completed_groups():
    # group is a list of RolloutResult messages for this prompt

    # 1. Compute rewards for each rollout
    rewards = compute_rewards(group)

    # 2. Compute GRPO advantages (relative within the group)
    advantages = grpo_advantages(rewards)

    # 3. Accumulate gradients
    grads = compute_grpo_gradients(group, advantages, params)
    accumulated_grads = accumulate(accumulated_grads, grads)

    prompts_processed += 1

    # 4. After enough prompts, apply gradients and push weights
    if prompts_processed >= update_interval:
        params = apply_gradients(params, accumulated_grads)
        transfer_engine.update_weights(params)  # NCCL transfer to vLLM
        prompts_processed = 0
        accumulated_grads = None
```

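The `grpo_advantages` helper above is left abstract. One common definition, which you could substitute, is group-relative normalization: each rollout's advantage is its reward minus the group mean, divided by the group standard deviation (this is the standard GRPO formulation, not code from this repo):

```python
import statistics

def grpo_advantages(rewards, eps=1e-6):
    """Group-relative advantages: A_i = (r_i - mean) / (std + eps).
    eps guards against division by zero when all rewards are equal."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Rollouts are scored relative to each other within the group, so a
# uniform reward shift leaves the advantages unchanged:
print(grpo_advantages([1.0, 2.0, 3.0]))  # roughly [-1.22, 0.0, 1.22]
```

Because advantages are computed within a prompt's rollout group, no learned value function or cross-prompt baseline is needed.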
## Trade-offs

### Advantages

- **Higher throughput**: vLLM doesn't wait for JAX; dispatch doesn't wait for inference
- **Better GPU utilization**: JAX and vLLM can work in parallel
- **Incremental processing**: Results are processed as they arrive
- **No message loss**: The consumer thread reads results even when the main thread blocks on NCCL
- **Controllable staleness**: Use `--max-staleness` to limit how far vLLM runs ahead

### Considerations

- **Slightly stale weights**: vLLM may generate some rollouts with weights from the previous update
- **More complex state management**: Must track pending groups and correlate results
- **Staleness/throughput trade-off**: Lower `MAX_STALENESS` means fresher weights but lower throughput

## Library Components

This example uses the following library components from `jax_inference_offloading`:

### For vLLM Worker

```python
from jax_inference_offloading.controller.rollout_client import (
    RolloutServicer,            # Handles inference, weight updates, handshake
    AsyncRolloutClient,         # Adds staleness control and request queuing
    make_async_rollout_client,  # Factory function
)
```

### For JAX Controller

```python
from jax_inference_offloading.engines import (
    RolloutAccumulator,       # Thread-safe result accumulator
    result_consumer_thread,   # Background gRPC consumer
)
```

These components can be reused in custom implementations without copying the example code.

## Environment Variables

Set `VERBOSE_CONSUMER=1` to see each result as the consumer thread receives it (helpful for debugging).
