Skip to content

Commit 3185040

Browse files
vmoens and cursoragent committed
[Feature] Restructure LLM pip extras for backend flexibility
- Add llm-vllm, llm-sglang, llm-all extras for backend selection
- Base llm extra no longer includes inference backend
- Update sglang_nccl.py to use SGLang's native NCCL utilities
- Remove vLLM dependency from SGLang weight sync code

Users can now:
- pip install torchrl[llm-vllm] for vLLM backend
- pip install torchrl[llm-sglang] for SGLang backend
- pip install torchrl[llm-all] for both backends

Co-authored-by: Cursor <cursoragent@cursor.com>
ghstack-source-id: c9a6333
Pull-Request: #3436
1 parent df3a467 commit 3185040

File tree

2 files changed

+33
-34
lines changed

2 files changed

+33
-34
lines changed

pyproject.toml

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot; python_version>='3
9393
open_spiel = ["open_spiel>=1.5"]
9494
brax = ["jax>=0.7.0; python_version>='3.11'", "brax; python_version>='3.11'"]
9595
procgen = ["procgen"]
96+
# Base LLM dependencies (no inference backend - use llm-vllm or llm-sglang)
9697
llm = [
9798
"transformers",
98-
"vllm",
9999
"playwright",
100100
"datasets",
101101
"langdetect",
@@ -107,12 +107,27 @@ llm = [
107107
"einops",
108108
"safetensors",
109109
]
110+
# LLM with vLLM backend
111+
llm-vllm = [
112+
"torchrl[llm]",
113+
"vllm",
114+
]
115+
# LLM with SGLang backend
116+
llm-sglang = [
117+
"torchrl[llm]",
118+
"sglang[all]",
119+
]
120+
# LLM with both backends
121+
llm-all = [
122+
"torchrl[llm]",
123+
"vllm",
124+
"sglang[all]",
125+
]
110126
grpo = [
111127
# Core dependencies for GRPO training
112-
"datasets",
128+
"torchrl[llm-vllm]",
113129
"peft",
114130
"wandb",
115-
"vllm",
116131
"transformers",
117132
"accelerate",
118133
"ray",
@@ -121,10 +136,6 @@ grpo = [
121136
"flash-attn",
122137
"bitsandbytes",
123138
"xformers",
124-
# LLM-related dependencies that may be needed
125-
"nltk",
126-
"langdetect",
127-
"immutabledict",
128139
]
129140
dev = [
130141
"pre-commit",

torchrl/weight_update/llm/sglang_nccl.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -191,33 +191,21 @@ def init_all_workers_group(
191191
# Step 2: Initialize trainer's NCCL communicator
192192
torch.cuda.set_device(self.device)
193193

194-
try:
195-
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
196-
from vllm.distributed.utils import StatelessProcessGroup
197-
198-
pg = StatelessProcessGroup.create(
199-
host=self.master_address,
200-
port=self.master_port,
201-
rank=0,
202-
world_size=self.world_size,
203-
)
204-
self._comm_group = PyNcclCommunicator(
205-
pg, device=torch.device(f"cuda:{self.device}")
206-
)
207-
except ImportError:
208-
# Fallback to torch.distributed if vLLM not available
209-
torchrl_logger.warning(
210-
"vLLM not available, falling back to torch.distributed for NCCL init. "
211-
"This may not be compatible with SGLang's NCCL implementation."
212-
)
213-
if not torch.distributed.is_initialized():
214-
torch.distributed.init_process_group(
215-
backend="nccl",
216-
init_method=f"tcp://{self.master_address}:{self.master_port}",
217-
rank=0,
218-
world_size=self.world_size,
219-
)
220-
self._comm_group = torch.distributed.distributed_c10d._get_default_group()
194+
# Use SGLang's native NCCL utilities (no vLLM dependency)
195+
from sglang.srt.distributed.device_communicators.pynccl import (
196+
PyNcclCommunicator,
197+
)
198+
from sglang.srt.distributed.utils import StatelessProcessGroup
199+
200+
pg = StatelessProcessGroup.create(
201+
host=self.master_address,
202+
port=self.master_port,
203+
rank=0,
204+
world_size=self.world_size,
205+
)
206+
self._comm_group = PyNcclCommunicator(
207+
pg, device=torch.device(f"cuda:{self.device}")
208+
)
221209

222210
torchrl_logger.info("NCCL group initialized successfully")
223211

0 commit comments

Comments (0)