
Commit 2f7a494

vmoens and cursoragent committed

[Feature] Add SGLang backend support to GRPO

- Add inference_model.backend config option ("vllm" or "sglang")
- Refactor get_inference_model() to support both backends
- Refactor make_weight_sync_scheme() to support both backends
- Add _get_sglang_inference_model() for SGLang backend
- Add _make_sglang_weight_sync_scheme() for SGLang weight sync

Users can now run GRPO with either vLLM or SGLang:

    inference_model:
      backend: "sglang"  # or "vllm" (default)

Co-authored-by: Cursor <cursoragent@cursor.com>
ghstack-source-id: 8c27963
Pull-Request: #3437

1 parent 89773e9 commit 2f7a494
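As a rough usage sketch, the new option can be read back like this (assuming OmegaConf, the config library Hydra builds on; the inline config mirrors the snippet in the commit message and is not a complete GRPO run config):

```python
# Minimal sketch of the new knob introduced by this commit.
# The inline config is illustrative only, not a full GRPO config.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    inference_model:
      backend: "sglang"   # or "vllm" (the default when the key is absent)
      num_devices: 1
    """
)
print(cfg.inference_model.backend)  # -> "sglang"
```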

File tree

5 files changed: 154 additions, 26 deletions


sota-implementations/grpo/config/grpo_gsm8k.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -88,6 +88,7 @@ train_model:
 
 # Inference model configuration
 inference_model:
+  backend: "vllm"  # Inference backend: "vllm" or "sglang"
   num_devices: 1  # Number of devices to use
   quantization:
     enabled: false  # Enable 4-bit quantization for base model
```

sota-implementations/grpo/config/grpo_ifeval.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -88,6 +88,7 @@ train_model:
 
 # Inference model configuration
 inference_model:
+  backend: "vllm"  # Inference backend: "vllm" or "sglang"
   num_devices: 2  # Number of devices to use
   quantization:
     enabled: false  # Enable 4-bit quantization for base model
```

sota-implementations/grpo/grpo-async.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -109,10 +109,10 @@ def train(
     if cfg.model.compile:
         loss_fn = torch.compile(loss_fn)
 
-    vllm_engine = inference_policy.model
+    inference_engine = inference_policy.model
 
     # Create weight sync scheme for the collectors
-    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)
+    weight_sync_scheme = make_weight_sync_scheme(engine=inference_engine, cfg=cfg)
 
     # Set up weight sync scheme for collectors
     # Note: We need to get the sender after the collectors are created
@@ -127,7 +127,7 @@ def train(
     # Initialize collective group
     torchrl_logger.info("Initializing collective group...")
     metadata = get_model_metadata(policy_training)
-    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
+    sender.init_all_workers_group(metadata, vllm_engine=inference_engine)
 
     # First weight update
     with timeit("update_policy_weights"):
```
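Note the asymmetry in the updated call sites: the local variable becomes backend-neutral (`inference_engine`), while `init_all_workers_group` still receives the engine through its historical `vllm_engine=` keyword, since only the caller side was renamed in this commit. A stand-in sketch (MockSender is hypothetical, not torchrl's sender class):

```python
# Hypothetical stand-in illustrating why the keyword stays `vllm_engine=`:
# the sender's signature predates SGLang support, so any engine object
# (vLLM or SGLang) is simply bound to the existing parameter name.
class MockSender:
    def init_all_workers_group(self, metadata, vllm_engine=None):
        self.engine = vllm_engine
        self.metadata = metadata

sender = MockSender()
inference_engine = object()  # stands in for inference_policy.model
sender.init_all_workers_group({"dtype": "bfloat16"}, vllm_engine=inference_engine)
assert sender.engine is inference_engine
```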

sota-implementations/grpo/grpo-sync.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -110,10 +110,10 @@ def train(
     if cfg.model.compile:
         loss_fn = torch.compile(loss_fn)
 
-    vllm_engine = inference_policy.model
+    inference_engine = inference_policy.model
 
     # Create weight sync scheme
-    weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine)
+    weight_sync_scheme = make_weight_sync_scheme(engine=inference_engine, cfg=cfg)
 
     # Set up weight sender
     torchrl_logger.info("Setting up weight synchronization scheme...")
@@ -123,7 +123,7 @@ def train(
     # Initialize collective group
     torchrl_logger.info("Initializing collective group...")
     metadata = get_model_metadata(policy_training)
-    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
+    sender.init_all_workers_group(metadata, vllm_engine=inference_engine)
 
     # First weight update
     with timeit("update_policy_weights"):
```

sota-implementations/grpo/grpo_utils.py

Lines changed: 146 additions & 20 deletions
```diff
@@ -17,8 +17,8 @@
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
 from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
-from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
-from torchrl.weight_update.llm import VLLMWeightSyncScheme
+from torchrl.modules.llm import SGLangWrapper, TransformersWrapper, vLLMWrapper
+from torchrl.weight_update.llm import SGLangWeightSyncScheme, VLLMWeightSyncScheme
 from transformers.models.auto.modeling_auto import AutoModelForCausalLM
 from transformers.tokenization_utils import PreTrainedTokenizer
```

```diff
@@ -183,27 +183,41 @@ def get_inference_model(
     devices: list[int] | None = None,
     make_ray_worker: bool = True,
     tokenizer: PreTrainedTokenizer | None = None,
-) -> vLLMWrapper:
-    """Creates the vLLM-based inference model for fast generation.
+) -> vLLMWrapper | SGLangWrapper:
+    """Creates the inference model for fast generation.
 
-    This function initializes a vLLM model server for efficient inference and wraps
-    it in a vLLMWrapper for policy inference. vLLM provides optimized generation
-    with better throughput than standard HuggingFace generation.
+    This function initializes a model server (vLLM or SGLang) for efficient inference
+    and wraps it in the appropriate wrapper for policy inference.
 
     Args:
         cfg (DictConfig): The hydra configuration object containing model settings.
-            Expected to have inference_model section with vLLM-specific parameters
-            like gpu_memory_utilization and generation settings.
+            Expected to have inference_model section with backend-specific parameters.
+            Set inference_model.backend to "vllm" or "sglang" to select the backend.
         devices (list[int], optional): The devices to use for the inference model. Default: `None`.
         make_ray_worker (bool, optional): Whether to make a ray worker. Default: `True`.
         tokenizer (PreTrainedTokenizer, optional): The tokenizer to use with the inference model. Default: `None`.
 
     Returns:
-        vLLMWrapper: The wrapped vLLM model ready for inference.
+        vLLMWrapper | SGLangWrapper: The wrapped model ready for inference.
 
     Raises:
-        AssertionError: If the vLLM server or model initialization fails
+        AssertionError: If the server or model initialization fails
     """
+    backend = getattr(cfg.inference_model, "backend", "vllm")
+
+    if backend == "sglang":
+        return _get_sglang_inference_model(cfg, devices, make_ray_worker, tokenizer)
+    else:
+        return _get_vllm_inference_model(cfg, devices, make_ray_worker, tokenizer)
+
+
+def _get_vllm_inference_model(
+    cfg: DictConfig,
+    devices: list[int] | None = None,
+    make_ray_worker: bool = True,
+    tokenizer: PreTrainedTokenizer | None = None,
+) -> vLLMWrapper:
+    """Creates the vLLM-based inference model."""
     from torchrl.modules.llm.backends.vllm import AsyncVLLM
 
     num_devices = cfg.inference_model.num_devices
```
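The dispatch above is deliberately backward compatible: configs written before this commit have no `inference_model.backend` key, and `getattr` with a default routes them to vLLM. A runnable sketch with plain objects standing in for the Hydra DictConfig and stubbed builders:

```python
from types import SimpleNamespace

def _get_vllm_inference_model(cfg):    # stub, not the real builder above
    return "vllm-policy"

def _get_sglang_inference_model(cfg):  # stub, not the real builder above
    return "sglang-policy"

def get_inference_model(cfg):
    # Same lookup as in grpo_utils.py: a missing key falls back to "vllm".
    backend = getattr(cfg.inference_model, "backend", "vllm")
    if backend == "sglang":
        return _get_sglang_inference_model(cfg)
    return _get_vllm_inference_model(cfg)

legacy_cfg = SimpleNamespace(inference_model=SimpleNamespace(num_devices=1))
new_cfg = SimpleNamespace(inference_model=SimpleNamespace(backend="sglang"))
print(get_inference_model(legacy_cfg))  # -> "vllm-policy"
print(get_inference_model(new_cfg))     # -> "sglang-policy"
```

Under Hydra's struct-mode DictConfig, accessing a missing key raises an AttributeError subclass, so the same fallback should apply there as well.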
```diff
@@ -261,6 +275,8 @@ def get_inference_model(
 
     # Handle FP32 output configuration
     if hasattr(cfg.inference_model, "enable_fp32_output"):
+        import os
+
         enable_fp32 = cfg.inference_model.enable_fp32_output
         if enable_fp32:
             os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
```
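A self-contained sketch of the flag-to-env-var pattern used in this hunk (the cfg object is a stand-in; the env var name is taken from the diff):

```python
import os
from types import SimpleNamespace

# Stand-in for the Hydra config; only the flag relevant here is modeled.
cfg = SimpleNamespace(inference_model=SimpleNamespace(enable_fp32_output=True))

if hasattr(cfg.inference_model, "enable_fp32_output"):
    if cfg.inference_model.enable_fp32_output:
        # Set before the engine is created so spawned workers inherit it.
        os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"

print(os.environ.get("VLLM_ENABLE_FP32_OUTPUT"))  # -> "1"
```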
```diff
@@ -326,6 +342,85 @@ def get_inference_model(
     return policy
 
 
+def _get_sglang_inference_model(
+    cfg: DictConfig,
+    devices: list[int] | None = None,
+    make_ray_worker: bool = True,
+    tokenizer: PreTrainedTokenizer | None = None,
+) -> SGLangWrapper:
+    """Creates the SGLang-based inference model."""
+    from torchrl.modules.llm.backends.sglang import AsyncSGLang
+
+    num_devices = cfg.inference_model.num_devices
+    if num_devices is None:
+        sglang_devices = devices if devices is not None else [1]
+        num_devices = len(sglang_devices)
+    else:
+        sglang_devices = None
+    torchrl_logger.info(
+        f"Creating AsyncSGLang inference model with num_devices={num_devices}, devices={sglang_devices}"
+    )
+
+    model_name = cfg.model.name
+
+    # Build parameters for AsyncSGLang
+    inference_params = {
+        "model_name": model_name,
+        "tp_size": num_devices,
+        "mem_fraction_static": getattr(
+            cfg.inference_model, "gpu_memory_utilization", 0.9
+        ),
+    }
+
+    # Handle torch_dtype
+    if hasattr(cfg.inference_model, "torch_dtype"):
+        dtype_str = cfg.inference_model.torch_dtype
+        if dtype_str is not None:
+            if isinstance(dtype_str, str):
+                inference_params["dtype"] = getattr(torch, dtype_str)
+            else:
+                inference_params["dtype"] = dtype_str
+
+    # Add optional SGLang parameters
+    optional_sglang_params = [
+        "trust_remote_code",
+        "dp_size",
+    ]
+
+    for param in optional_sglang_params:
+        if hasattr(cfg.inference_model, param):
+            value = getattr(cfg.inference_model, param)
+            if value is not None:
+                inference_params[param] = value
+
+    inference_server = AsyncSGLang.from_pretrained(**inference_params)
+    assert inference_server is not None
+
+    if tokenizer is None:
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if tokenizer.pad_token == tokenizer.eos_token:
+            tokenizer.pad_token = "PAD"
+        tokenizer.padding_side = "left"
+
+    policy = SGLangWrapper(
+        inference_server,
+        input_mode="history",
+        chat_template_name="qwen",
+        return_log_probs=not cfg.env.reasoning,
+        tokenizer=tokenizer,
+        pad_output=False,
+        generate_kwargs={
+            "max_new_tokens": cfg.inference_model.max_tokens,
+            "temperature": cfg.inference_model.temperature,
+            "top_p": cfg.inference_model.top_p,
+        },
+    )
+    assert policy.model is not None
+    return policy
+
+
 def get_ref_model(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizer,
```
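Two details of the new builder are worth isolating: the vLLM-named `gpu_memory_utilization` config key is translated to SGLang's `mem_fraction_static` parameter, and optional keys are forwarded only if they exist on the config and are non-None. A runnable sketch of that filtering (the config object and model name are made-up stand-ins; the parameter names come from the function above):

```python
from types import SimpleNamespace

# Stand-in inference_model config; "Qwen/Qwen2.5-3B" is a made-up example.
inf_cfg = SimpleNamespace(
    gpu_memory_utilization=0.8,
    trust_remote_code=True,
    dp_size=None,  # present but None -> filtered out below
)

inference_params = {
    "model_name": "Qwen/Qwen2.5-3B",
    "tp_size": 2,
    "mem_fraction_static": getattr(inf_cfg, "gpu_memory_utilization", 0.9),
}

# Same filter as in _get_sglang_inference_model: forward an optional key
# only if the config defines it AND the value is not None.
for param in ("trust_remote_code", "dp_size"):
    if hasattr(inf_cfg, param):
        value = getattr(inf_cfg, param)
        if value is not None:
            inference_params[param] = value

print(inference_params)
# {'model_name': 'Qwen/Qwen2.5-3B', 'tp_size': 2,
#  'mem_fraction_static': 0.8, 'trust_remote_code': True}
```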
```diff
@@ -557,22 +652,34 @@ def get_hf_model(
 
 
 def make_weight_sync_scheme(
-    vllm_engine,
-) -> VLLMWeightSyncScheme:
-    """Creates a vLLM weight synchronization scheme using NCCL collectives.
+    engine,
+    cfg: DictConfig,
+) -> VLLMWeightSyncScheme | SGLangWeightSyncScheme:
+    """Creates a weight synchronization scheme using NCCL collectives.
 
     This function creates a weight sync scheme that uses NCCL for high-performance
-    GPU-to-GPU weight transfers from the training model to vLLM inference workers.
+    GPU-to-GPU weight transfers from the training model to inference workers.
 
     Args:
-        vllm_engine: A vLLM engine implementing the RLvLLMEngine interface
-            (like RayLLMWorker, LocalLLMWrapper, or AsyncVLLM).
-            This is typically obtained from the inference policy's model attribute.
+        engine: An inference engine implementing the RLvLLMEngine or RLSGLangEngine
+            interface. This is typically obtained from the inference policy's model
+            attribute.
+        cfg: The hydra configuration object. Used to determine the backend type.
 
     Returns:
-        VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
+        VLLMWeightSyncScheme | SGLangWeightSyncScheme: A weight sync scheme
+            configured for the inference engine.
     """
-    # Get configuration from the vLLM engine
+    backend = getattr(cfg.inference_model, "backend", "vllm")
+
+    if backend == "sglang":
+        return _make_sglang_weight_sync_scheme(engine)
+    else:
+        return _make_vllm_weight_sync_scheme(engine)
+
+
+def _make_vllm_weight_sync_scheme(vllm_engine) -> VLLMWeightSyncScheme:
+    """Creates a vLLM weight synchronization scheme."""
     tp_size = vllm_engine.get_tp_size()
     num_replicas = getattr(vllm_engine, "num_replicas", 1)
     master_address = vllm_engine.get_master_address()
```
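The vLLM path introspects the engine rather than the config: `num_replicas` is read with a default of 1 so engines that never expose that attribute still work. A sketch with hypothetical engine objects (neither class is part of torchrl; they only model the attributes being read):

```python
class ReplicatedEngine:
    num_replicas = 4

    def get_tp_size(self):
        return 2

    def get_master_address(self):
        return "10.0.0.1"  # made-up address

class SingleEngine:  # exposes no num_replicas attribute at all
    def get_tp_size(self):
        return 1

    def get_master_address(self):
        return "127.0.0.1"

for engine in (ReplicatedEngine(), SingleEngine()):
    tp_size = engine.get_tp_size()
    num_replicas = getattr(engine, "num_replicas", 1)  # default: 1 replica
    print(tp_size, num_replicas, engine.get_master_address())
# -> 2 4 10.0.0.1
# -> 1 1 127.0.0.1
```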
```diff
@@ -593,6 +700,25 @@ def make_weight_sync_scheme(
     )
 
 
+def _make_sglang_weight_sync_scheme(sglang_engine) -> SGLangWeightSyncScheme:
+    """Creates an SGLang weight synchronization scheme."""
+    server_url = sglang_engine.server_url
+    tp_size = sglang_engine.get_tp_size()
+    dp_size = getattr(sglang_engine, "dp_size", 1)
+    num_gpus = tp_size * dp_size
+
+    torchrl_logger.info(
+        f"Creating SGLangWeightSyncScheme with server_url={server_url}, "
+        f"tp_size={tp_size}, dp_size={dp_size}, num_gpus={num_gpus}"
+    )
+
+    return SGLangWeightSyncScheme(
+        server_url=server_url,
+        num_gpus=num_gpus,
+        strategy_name="state_dict",
+    )
+
+
 def compute_device_allocation(cfg):
     """Compute device allocations and Ray GPU config.
 
```
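The GPU count fed to the scheme is the product of the tensor-parallel and data-parallel sizes, which suggests every inference rank joins the sync group to receive weights. A sketch with a hypothetical engine object standing in for the AsyncSGLang handle (only the attributes read by `_make_sglang_weight_sync_scheme` are modeled):

```python
class MockSGLangEngine:
    server_url = "http://127.0.0.1:30000"  # made-up address
    dp_size = 2                            # 2 data-parallel groups

    def get_tp_size(self):
        return 4                           # 4-way tensor parallelism

engine = MockSGLangEngine()
num_gpus = engine.get_tp_size() * getattr(engine, "dp_size", 1)
print(num_gpus)  # -> 8: 4-way TP replicated across 2 DP groups
```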
