1717from torchrl ._utils import logger as torchrl_logger , timeit
1818from torchrl .envs .llm import AddThinkingPrompt , GSM8KEnv , KLRewardTransform , RetrieveKL
1919from torchrl .envs .llm .datasets .ifeval import IFEvalEnv
20- from torchrl .modules .llm import TransformersWrapper , vLLMWrapper
21- from torchrl .weight_update .llm import VLLMWeightSyncScheme
20+ from torchrl .modules .llm import SGLangWrapper , TransformersWrapper , vLLMWrapper
21+ from torchrl .weight_update .llm import SGLangWeightSyncScheme , VLLMWeightSyncScheme
2222from transformers .models .auto .modeling_auto import AutoModelForCausalLM
2323from transformers .tokenization_utils import PreTrainedTokenizer
2424
@@ -183,27 +183,41 @@ def get_inference_model(
183183 devices : list [int ] | None = None ,
184184 make_ray_worker : bool = True ,
185185 tokenizer : PreTrainedTokenizer | None = None ,
186- ) -> vLLMWrapper :
187- """Creates the vLLM-based inference model for fast generation.
186+ ) -> vLLMWrapper | SGLangWrapper :
187+ """Creates the inference model for fast generation.
188188
189- This function initializes a vLLM model server for efficient inference and wraps
190- it in a vLLMWrapper for policy inference. vLLM provides optimized generation
191- with better throughput than standard HuggingFace generation.
189+ This function initializes a model server (vLLM or SGLang) for efficient inference
190+ and wraps it in the appropriate wrapper for policy inference.
192191
193192 Args:
194193 cfg (DictConfig): The hydra configuration object containing model settings.
195- Expected to have inference_model section with vLLM -specific parameters
196- like gpu_memory_utilization and generation settings .
194+ Expected to have inference_model section with backend -specific parameters.
195+ Set inference_model.backend to "vllm" or "sglang" to select the backend .
197196 devices (list[int], optional): The devices to use for the inference model. Default: `None`.
198197 make_ray_worker (bool, optional): Whether to make a ray worker. Default: `True`.
199198 tokenizer (PreTrainedTokenizer, optional): The tokenizer to use with the inference model. Default: `None`.
200199
201200 Returns:
202- vLLMWrapper: The wrapped vLLM model ready for inference.
201+ vLLMWrapper | SGLangWrapper : The wrapped model ready for inference.
203202
204203 Raises:
205- AssertionError: If the vLLM server or model initialization fails
204+ AssertionError: If the server or model initialization fails
206205 """
206+ backend = getattr (cfg .inference_model , "backend" , "vllm" )
207+
208+ if backend == "sglang" :
209+ return _get_sglang_inference_model (cfg , devices , make_ray_worker , tokenizer )
210+ else :
211+ return _get_vllm_inference_model (cfg , devices , make_ray_worker , tokenizer )
212+
213+
214+ def _get_vllm_inference_model (
215+ cfg : DictConfig ,
216+ devices : list [int ] | None = None ,
217+ make_ray_worker : bool = True ,
218+ tokenizer : PreTrainedTokenizer | None = None ,
219+ ) -> vLLMWrapper :
220+ """Creates the vLLM-based inference model."""
207221 from torchrl .modules .llm .backends .vllm import AsyncVLLM
208222
209223 num_devices = cfg .inference_model .num_devices
@@ -261,6 +275,8 @@ def get_inference_model(
261275
262276 # Handle FP32 output configuration
263277 if hasattr (cfg .inference_model , "enable_fp32_output" ):
278+ import os
279+
264280 enable_fp32 = cfg .inference_model .enable_fp32_output
265281 if enable_fp32 :
266282 os .environ ["VLLM_ENABLE_FP32_OUTPUT" ] = "1"
@@ -326,6 +342,85 @@ def get_inference_model(
326342 return policy
327343
328344
def _get_sglang_inference_model(
    cfg: DictConfig,
    devices: list[int] | None = None,
    make_ray_worker: bool = True,
    tokenizer: PreTrainedTokenizer | None = None,
) -> SGLangWrapper:
    """Creates the SGLang-based inference model.

    Builds an ``AsyncSGLang`` server from the hydra config and wraps it in an
    :class:`SGLangWrapper` configured for history-mode generation.

    Args:
        cfg (DictConfig): Hydra config; reads ``model.name``, ``env.reasoning``
            and the ``inference_model`` section (``num_devices``,
            ``gpu_memory_utilization``, ``torch_dtype``, sampling params, ...).
        devices (list[int], optional): Devices for the inference model. Only
            consulted when ``cfg.inference_model.num_devices`` is ``None``.
            Default: ``None``.
        make_ray_worker (bool, optional): Whether to make a ray worker.
            Default: ``True``.
        tokenizer (PreTrainedTokenizer, optional): Tokenizer to use; loaded
            from ``model.name`` when omitted. Default: ``None``.

    Returns:
        SGLangWrapper: The wrapped SGLang model ready for inference.
    """
    from torchrl.modules.llm.backends.sglang import AsyncSGLang

    # Resolve device count: an explicit num_devices wins; otherwise fall back
    # to the provided device list (or the historical default of device 1).
    num_devices = cfg.inference_model.num_devices
    sglang_devices = None
    if num_devices is None:
        sglang_devices = [1] if devices is None else devices
        num_devices = len(sglang_devices)
    torchrl_logger.info(
        f"Creating AsyncSGLang inference model with num_devices={num_devices}, devices={sglang_devices}"
    )

    model_name = cfg.model.name

    # Base parameters for the AsyncSGLang server.
    server_kwargs = {
        "model_name": model_name,
        "tp_size": num_devices,
        "mem_fraction_static": getattr(
            cfg.inference_model, "gpu_memory_utilization", 0.9
        ),
    }

    # Optional dtype: a string is resolved to the matching torch dtype,
    # anything else (already a dtype object) is forwarded as-is.
    configured_dtype = getattr(cfg.inference_model, "torch_dtype", None)
    if configured_dtype is not None:
        if isinstance(configured_dtype, str):
            configured_dtype = getattr(torch, configured_dtype)
        server_kwargs["dtype"] = configured_dtype

    # Forward any optional SGLang knobs that are present and non-None.
    for param in ("trust_remote_code", "dp_size"):
        value = getattr(cfg.inference_model, param, None)
        if value is not None:
            server_kwargs[param] = value

    inference_server = AsyncSGLang.from_pretrained(**server_kwargs)
    assert inference_server is not None

    if tokenizer is None:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Avoid pad==eos, which confuses masking downstream; pad on the left
        # for generation.
        if tokenizer.pad_token == tokenizer.eos_token:
            tokenizer.pad_token = "PAD"
        tokenizer.padding_side = "left"

    policy = SGLangWrapper(
        inference_server,
        input_mode="history",
        chat_template_name="qwen",
        # Log-probs are only requested when not in reasoning mode.
        return_log_probs=not cfg.env.reasoning,
        tokenizer=tokenizer,
        pad_output=False,
        generate_kwargs={
            "max_new_tokens": cfg.inference_model.max_tokens,
            "temperature": cfg.inference_model.temperature,
            "top_p": cfg.inference_model.top_p,
        },
    )
    assert policy.model is not None
    return policy
422+
423+
329424def get_ref_model (
330425 cfg : DictConfig ,
331426 tokenizer : PreTrainedTokenizer ,
@@ -557,22 +652,34 @@ def get_hf_model(
557652
558653
def make_weight_sync_scheme(
    engine,
    cfg: DictConfig,
) -> VLLMWeightSyncScheme | SGLangWeightSyncScheme:
    """Creates a weight synchronization scheme using NCCL collectives.

    This function creates a weight sync scheme that uses NCCL for high-performance
    GPU-to-GPU weight transfers from the training model to inference workers.

    Args:
        engine: An inference engine implementing the RLvLLMEngine or RLSGLangEngine
            interface. This is typically obtained from the inference policy's model
            attribute.
        cfg: The hydra configuration object. Used to determine the backend type.

    Returns:
        VLLMWeightSyncScheme | SGLangWeightSyncScheme: A weight sync scheme
            configured for the inference engine.
    """
    # Backend defaults to vLLM when the config does not specify one.
    backend = getattr(cfg.inference_model, "backend", "vllm")
    builder = (
        _make_sglang_weight_sync_scheme
        if backend == "sglang"
        else _make_vllm_weight_sync_scheme
    )
    return builder(engine)
679+
680+
681+ def _make_vllm_weight_sync_scheme (vllm_engine ) -> VLLMWeightSyncScheme :
682+ """Creates a vLLM weight synchronization scheme."""
576683 tp_size = vllm_engine .get_tp_size ()
577684 num_replicas = getattr (vllm_engine , "num_replicas" , 1 )
578685 master_address = vllm_engine .get_master_address ()
@@ -593,6 +700,25 @@ def make_weight_sync_scheme(
593700 )
594701
595702
def _make_sglang_weight_sync_scheme(sglang_engine) -> SGLangWeightSyncScheme:
    """Creates an SGLang weight synchronization scheme.

    Args:
        sglang_engine: An SGLang engine exposing ``server_url``,
            ``get_tp_size()`` and (optionally) ``dp_size``.

    Returns:
        SGLangWeightSyncScheme: Scheme targeting the engine's server with one
        GPU slot per (tensor-parallel x data-parallel) rank.
    """
    # dp_size is optional on the engine; assume a single data-parallel
    # replica when absent.
    dp_size = getattr(sglang_engine, "dp_size", 1)
    tp_size = sglang_engine.get_tp_size()
    server_url = sglang_engine.server_url
    # Total GPUs participating in the sync collective.
    num_gpus = tp_size * dp_size

    torchrl_logger.info(
        f"Creating SGLangWeightSyncScheme with server_url={server_url}, "
        f"tp_size={tp_size}, dp_size={dp_size}, num_gpus={num_gpus}"
    )

    return SGLangWeightSyncScheme(
        server_url=server_url,
        num_gpus=num_gpus,
        strategy_name="state_dict",
    )
720+
721+
596722def compute_device_allocation (cfg ):
597723 """Compute device allocations and Ray GPU config.
598724
0 commit comments