Skip to content

Commit 85d64c8

Browse files
committed
Update
[ghstack-poisoned]
1 parent efec57b commit 85d64c8

File tree

3 files changed

+422
-0
lines changed

3 files changed

+422
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""SGLang backends for TorchRL.
7+
8+
This module provides comprehensive SGLang integration including:
9+
- Base classes and interfaces
10+
- Asynchronous SGLang server services
11+
- Shared utilities
12+
13+
Examples:
14+
>>> # Connect to an existing SGLang server
15+
>>> from torchrl.modules.llm.backends.sglang import AsyncSGLang
16+
>>> service = AsyncSGLang.connect("http://localhost:30000")
17+
18+
>>> # Launch a managed SGLang server
19+
>>> from torchrl.modules.llm.backends.sglang import AsyncSGLang
20+
>>> service = AsyncSGLang.from_pretrained("Qwen/Qwen2.5-3B")
21+
22+
>>> # All engines implement the same interface
23+
>>> from torchrl.modules.llm.backends.sglang import RLSGLangEngine
24+
"""
25+
26+
from __future__ import annotations

import importlib
from typing import Any
29+
30+
# Public API of this package. Every name listed here is resolved lazily by
# the module-level ``__getattr__`` below (via ``_LAZY_ATTRS``), so importing
# this package does not eagerly pull in the SGLang submodules.
__all__ = [
    # Base classes and interfaces
    "RLSGLangEngine",
    # Asynchronous SGLang
    "AsyncSGLang",
    # Utilities
    "get_open_port",
    "wait_for_server",
]
39+
40+
# Maps a public attribute name to the (module, attribute) pair that provides
# it. Consumed by ``__getattr__``/``__dir__`` below (PEP 562 lazy loading).
_LAZY_ATTRS: dict[str, tuple[str, str]] = {
    # Base
    "RLSGLangEngine": ("torchrl.modules.llm.backends.sglang.base", "RLSGLangEngine"),
    # Async
    "AsyncSGLang": (
        "torchrl.modules.llm.backends.sglang.sglang_server",
        "AsyncSGLang",
    ),
    # Utils
    "get_open_port": (
        "torchrl.modules.llm.backends.sglang.sglang_utils",
        "get_open_port",
    ),
    "wait_for_server": (
        "torchrl.modules.llm.backends.sglang.sglang_utils",
        "wait_for_server",
    ),
}


def __getattr__(name: str) -> Any:  # noqa: ANN401
    """Resolve a public attribute lazily on first access (PEP 562).

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The object registered for ``name`` in ``_LAZY_ATTRS``, imported on
        demand from its providing submodule.

    Raises:
        AttributeError: If ``name`` is not a registered lazy attribute.
    """
    target = _LAZY_ATTRS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module_name, attr_name = target
    attr = getattr(importlib.import_module(module_name), attr_name)
    # Cache the resolved object on the module so subsequent accesses hit the
    # module __dict__ directly and never re-enter __getattr__.
    globals()[name] = attr
    return attr


def __dir__() -> list[str]:
    """Expose lazily-loaded names in ``dir()`` alongside regular globals."""
    return sorted(set(globals()) | set(_LAZY_ATTRS))
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Base classes for TorchRL SGLang backends."""
7+
8+
from __future__ import annotations
9+
10+
import abc
11+
from collections.abc import Iterator
12+
13+
import torch
14+
15+
16+
class RLSGLangEngine(abc.ABC):
    """Interface for TorchRL SGLang engines that participate in weight updates.

    Concrete SGLang engines subclass this ABC and supply every abstract
    method below. Generation goes through HTTP calls against the SGLang
    server, while weight synchronization during RL training relies on NCCL.

    Example:
        >>> # Any concrete engine fills in the same contract
        >>> class MySGLangEngine(RLSGLangEngine):
        ...     def get_tp_size(self) -> int:
        ...         return self._tp_size
        ...
        ...     def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
        ...         return self._model_metadata
        ...
        ...     # ... implement other abstract methods
    """

    @abc.abstractmethod
    def get_tp_size(self) -> int:
        """Report the engine's tensor parallel size.

        Returns:
            int: Tensor parallel size
        """

    @abc.abstractmethod
    def get_dp_size(self) -> int:
        """Report the engine's data parallel size.

        Returns:
            int: Data parallel size
        """

    @abc.abstractmethod
    def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Describe the model's parameters.

        Returns:
            dict: Mapping of parameter names to (dtype, shape) tuples
        """

    @abc.abstractmethod
    def get_master_address(self) -> str:
        """Report the master address used for weight synchronization.

        Returns:
            str: Master address (e.g., "localhost")
        """

    @abc.abstractmethod
    def get_master_port(self) -> int:
        """Report the master port used for weight synchronization.

        Returns:
            int: Master port number
        """

    @abc.abstractmethod
    def init_weight_update_group(
        self,
        master_address: str | None = None,
        master_port: int | None = None,
    ) -> None:
        """Set up the weight-update communication group.

        Implementations are expected to establish NCCL communication for
        weight broadcasting through the SGLang server's
        /init_weights_update_group API.

        Args:
            master_address: Override for master address. If None, uses default.
            master_port: Override for master port. If None, uses default.
        """

    @abc.abstractmethod
    def update_weights_from_distributed(
        self,
        name: str,
        dtype: torch.dtype,
        shape: tuple[int, ...],
    ) -> None:
        """Tell the server to receive one weight tensor over NCCL broadcast.

        Coordinates with the SGLang server's /update_weights_from_distributed
        API so the server accepts a single parameter broadcast by the trainer.

        Args:
            name: Name of the parameter to update
            dtype: Data type of the tensor
            shape: Shape of the tensor
        """

    @abc.abstractmethod
    def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
        """Apply a stream of weight updates.

        Implementations perform the actual broadcasting/updating of each
        tensor via NCCL communication.

        Args:
            weights: Iterator yielding (parameter_name, tensor) tuples
        """

0 commit comments

Comments
 (0)