TL;DR
Replace SkyRL's fragmented inference abstractions (InferenceEngineClient, RayWrappedInferenceEngine, RemoteInferenceEngine) with a single InferenceHandle interface that standardizes on HTTP endpoints. This enables seamless integration with frameworks like Ray Serve LLM, making it straightforward to deploy wide expert parallelism (WideEP) and prefill-decode (PD) disaggregation.
Motivation
Current Problems
| Problem | Impact |
|---|---|
| Dual code paths | run_engines_locally creates two completely different flows (Ray actors vs HTTP), doubling maintenance burden |
| Scattered weight sync | Trainer manually orchestrates NCCL groups, rank offsets, IPC handles—complexity that should be server-side |
| Backend-specific logic | if engine_backend == "vllm" / "sglang" checks scattered across 5+ files |
| No clear API boundary | Generator and Trainer both reach into InferenceEngineClient internals |
Why Now?
Ray Serve LLM already provides multi-replica orchestration, fault tolerance, and OpenAI-compatible endpoints. Extending it with weight sync and lifecycle APIs lets SkyRL inherit these capabilities without reimplementing them. It is also a simple path to WideEP and PD disaggregation during rollouts, which are essential for high-throughput trajectory generation.
Proposal
Single Abstraction
```python
from abc import ABC
from typing import Any, Dict

class InferenceHandle(ABC):
    """Unified interface for inference interactions."""
    # Data Plane (Generator)
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: ...
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: ...
    async def generate(self, input_batch: "InferenceEngineInput") -> "InferenceEngineOutput": ...

    # Control Plane (Trainer)
    async def init_weight_transfer(self, master_addr: str, master_port: int,
                                   rank_offset: int, world_size: int) -> Dict[str, Any]: ...
    async def collective_rpc(self, model: str, method: str, args: list, kwargs: Dict[str, Any]) -> Dict[str, Any]: ...
    async def pause(self) -> None: ...
    async def resume(self) -> None: ...
    async def sleep(self) -> None: ...
    async def wakeup(self) -> None: ...
    async def server_info(self) -> Dict[str, Any]: ...
```
Client Implementation
All server patterns use the same client: RemoteInferenceHandle (HTTP). For in-process Ray Serve deployments, an optional RayServeInferenceHandle can bypass the HTTP overhead.
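For illustration, a minimal sketch of what RemoteInferenceHandle could look like, assuming httpx as the HTTP client. The endpoint paths mirror the standard endpoints listed below; the request body fields for weight sync are illustrative, not a confirmed schema.

```python
import httpx
from typing import Any, Dict

class RemoteInferenceHandle(InferenceHandle):
    """Talks to any server that exposes the standard HTTP endpoints."""

    def __init__(self, base_url: str):
        # base_url points at the router or Ray Serve ingress.
        self.client = httpx.AsyncClient(base_url=base_url.rstrip("/"), timeout=None)

    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        resp = await self.client.post("/v1/chat/completions", json=request_payload)
        resp.raise_for_status()
        return resp.json()

    async def pause(self) -> None:
        (await self.client.post("/pause")).raise_for_status()

    async def init_weight_transfer(self, master_addr: str, master_port: int,
                                   rank_offset: int, world_size: int) -> Dict[str, Any]:
        # Field names are illustrative; the actual schema is defined by the server.
        resp = await self.client.post("/init_weight_update_communicator", json={
            "master_addr": master_addr, "master_port": master_port,
            "rank_offset": rank_offset, "world_size": world_size,
        })
        resp.raise_for_status()
        return resp.json()
```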
Two Server Patterns
| Pattern | Server | Features | Best For |
|---|---|---|---|
| Ray Serve LLM | SkyRLIngress via serve.run() | WideEP, PD, fault tolerance, sleep/wake | Production training |
| Lightweight Router | N × vllm serve + FastAPI router | Round-robin, control-plane fan-out, weight sync | Simple cases, tests, etc. |
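To make the lightweight-router pattern concrete, here is a hedged sketch using FastAPI and httpx. The replica URLs are placeholders, and it assumes each replica also exposes the control-plane endpoints described below (e.g. /pause).

```python
import itertools
import httpx
from fastapi import FastAPI, Request

app = FastAPI()
# Placeholder replica addresses: N independent `vllm serve` processes.
REPLICAS = ["http://localhost:8001", "http://localhost:8002"]
_next_replica = itertools.cycle(REPLICAS)

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    # Data plane: round-robin each request to a single replica.
    target = next(_next_replica)
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(f"{target}/v1/chat/completions", json=await request.json())
        return resp.json()

@app.post("/pause")
async def pause():
    # Control plane: fan out to every replica so all engines pause together.
    async with httpx.AsyncClient(timeout=None) as client:
        for replica in REPLICAS:
            await client.post(f"{replica}/pause")
    return {"status": "ok"}
```

Weight-sync and other control-plane endpoints would fan out to all replicas in the same way.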
Standard Endpoints
All servers expose:
- Data Plane: /v1/chat/completions, /v1/completions, /tokenize
- Control Plane: /pause, /resume, /sleep, /wakeup, /reset_prefix_cache
- Weight Sync: /init_weight_update_communicator, /collective_rpc, /server_info
Integration
Driver Hook
```python
class BasePPOExp:
    def setup_inference_handle(self) -> InferenceHandle:
        # Default: connect to external server (works with any of the 3 patterns)
        return RemoteInferenceHandle(self.cfg.generator.inference_endpoint_url)
```
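For context, a hypothetical sketch of how Generator and Trainer code would consume the hook; the function name, model name, and addresses below are illustrative, not actual SkyRL APIs.

```python
async def run_step(exp: BasePPOExp) -> None:
    # Hypothetical wiring; names and values are illustrative.
    handle = exp.setup_inference_handle()

    # Generator side: data plane.
    reply = await handle.chat_completion(
        {"model": "policy", "messages": [{"role": "user", "content": "hello"}]}
    )

    # Trainer side: control plane around a weight update.
    await handle.pause()
    await handle.init_weight_transfer(
        master_addr="10.0.0.1", master_port=29500, rank_offset=0, world_size=8
    )
    await handle.resume()
```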
Migration Path
| Phase | Tasks | Invasiveness |
|---|---|---|
| 1. Foundation | Define InferenceHandle ABC, implement RemoteInferenceHandle, create SkyRLIngress, create Lightweight Router | Low (additive) |
| 2. Integration | Add setup_inference_handle() hook, update Generator/Trainer | Medium |
| 3. Validation | E2E tests, deprecate old layers | Low |
Feature flag: setting _use_inference_handle = True in the driver class enables the new path. The old code remains functional until it is deprecated.
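A sketch of how the flag could gate the two paths; the build_inference and legacy-builder names below are hypothetical.

```python
class BasePPOExp:
    _use_inference_handle = True  # feature flag; default shown here is illustrative

    def build_inference(self):
        if self._use_inference_handle:
            return self.setup_inference_handle()  # new InferenceHandle path
        # Hypothetical name for the existing InferenceEngineClient construction.
        return self._build_legacy_inference_engine_client()
```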
Benefits
- Unified: One interface for Generator and Trainer
- HTTP-first: Consistent API regardless of deployment
- Server-managed: Weight sync, lifecycle, caching complexity moves server-side
- Backend-agnostic: Same code works with vLLM, SGLang, or Ray Serve LLM; if SGLang support is needed, it can be added behind Ray Serve LLM.
- Future-proof: Ready for WideEP, PD, and other advanced patterns
Non-Goals
- Changes to Ray Serve LLM core to enable colocation (separate RFC)
- New weight sync mechanisms (uses existing collective_rpc)
- Breaking changes to existing configs (feature-flagged)