Skip to content

Commit 6d4bd84

Browse files
vmoens and cursoragent committed
[Test] Add SGLang weight synchronization tests
Add tests for SGLang weight sync scheme components: TestSGLangWeightSyncScheme: - test_scheme_initialization: Valid parameter configuration - test_scheme_auto_port: Auto-assigned master port - test_create_transport: Transport factory method - test_create_sender: Sender factory method - test_create_receiver_returns_none: SGLang manages receivers TestSGLangWeightSender: - test_register_model: Model registration - test_get_model_metadata: Metadata extraction utility - test_update_weights_requires_init: Error on uninitialized transport - test_update_weights_requires_model: Error on missing model TestSGLangCollectiveTransport: - test_transport_initialization: Parameter validation - test_transport_device_parsing: Device string/int/torch.device - test_check_connection_before_init: Connection state - test_send_weights_requires_init: Error handling - test_init_requires_rank_zero: Only rank 0 can init Markers: pytest.mark.gpu ghstack-source-id: 51bedfb Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: 51bedfb Pull-Request: #3434
1 parent 05e7d37 commit 6d4bd84

File tree

1 file changed

+324
-0
lines changed

1 file changed

+324
-0
lines changed

test/llm/test_sglang_updaters.py

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Tests for SGLang weight synchronization schemes."""
6+
from __future__ import annotations
7+
8+
import argparse
9+
import gc
10+
import importlib.util
11+
12+
import pytest
13+
import torch
14+
from torchrl._utils import logger
15+
16+
def _module_available(name: str) -> bool:
    """Return True if *name* can be imported in this environment."""
    return importlib.util.find_spec(name) is not None


# Optional-dependency flags used by the skipif markers on the test classes below.
_has_sglang = _module_available("sglang")
_has_transformers = _module_available("transformers")
19+
20+
21+
@pytest.mark.gpu
@pytest.mark.skipif(not _has_sglang, reason="sglang not available")
@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSGLangWeightSyncScheme:
    """Tests for SGLangWeightSyncScheme configuration.

    These tests only exercise scheme construction and its factory methods;
    no SGLang server connection is established.

    Note: the unused ``model_name`` fixture that duplicated the one in
    ``TestSGLangWeightSender`` was removed — no test in this class used it.
    """

    def test_scheme_initialization(self):
        """Test SGLangWeightSyncScheme initialization with valid parameters."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        scheme = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            master_address="localhost",
            master_port=29500,
            num_gpus=1,
            strategy="tensordict",
            device=0,
        )

        assert scheme.server_url == "http://localhost:30000"
        assert scheme.master_address == "localhost"
        assert scheme.master_port == 29500
        assert scheme.num_gpus == 1
        assert scheme.strategy_name == "tensordict"
        assert scheme.world_size == 2  # 1 trainer + 1 gpu

    def test_scheme_auto_port(self):
        """Test that master_port is auto-assigned when not provided."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        scheme = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=2,
        )

        # Auto-assigned port must fall in the valid TCP port range.
        assert scheme.master_port > 0
        assert scheme.master_port < 65536
        assert scheme.world_size == 3  # 1 trainer + 2 gpus

    def test_create_transport(self):
        """Test transport creation from scheme."""
        from torchrl.weight_update.llm import (
            SGLangCollectiveTransport,
            SGLangWeightSyncScheme,
        )

        scheme = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        )

        transport = scheme.create_transport()
        assert isinstance(transport, SGLangCollectiveTransport)
        assert transport.server_url == "http://localhost:30000"
        # The trainer side always occupies rank 0 of the collective group.
        assert transport.rank == 0
        assert transport.world_size == 2

    def test_create_sender(self):
        """Test sender creation from scheme."""
        from torchrl.weight_update.llm import SGLangWeightSender, SGLangWeightSyncScheme

        scheme = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        )

        sender = scheme.create_sender()
        assert isinstance(sender, SGLangWeightSender)

    def test_create_receiver_returns_none(self):
        """Test that create_receiver returns None (SGLang manages receivers)."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        scheme = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        )

        receiver = scheme.create_receiver()
        assert receiver is None
107+
108+
109+
@pytest.mark.gpu
@pytest.mark.skipif(not _has_sglang, reason="sglang not available")
@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSGLangWeightSender:
    """Tests for SGLangWeightSender registration and error handling."""

    @pytest.fixture(scope="class")
    def model_name(self):
        """Model name for testing - small model for faster testing."""
        return "Qwen/Qwen2.5-0.5B"

    @pytest.fixture(scope="class")
    def source_model(self, model_name):
        """Create source model for weight extraction."""
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cuda:0",
            torch_dtype=torch.float16,
        )

        yield model

        # Cleanup: drop the reference, then reclaim host and device memory.
        try:
            del model
        except Exception as e:
            logger.warning(f"Error during model cleanup: {e}")
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def test_register_model(self, source_model):
        """Registering a model stores a resolvable reference on the sender."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        sender = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        ).create_sender()
        sender.register_model(source_model)

        # The stored reference must resolve back to the registered model.
        ref = sender._model_ref
        assert ref is not None
        assert ref() is source_model

    def test_get_model_metadata(self, source_model):
        """Metadata extraction yields a non-empty {name: (dtype, shape)} mapping."""
        from torchrl.weight_update.llm import get_sglang_model_metadata

        metadata = get_sglang_model_metadata(source_model)

        assert isinstance(metadata, dict)
        assert metadata

        # Every entry maps a parameter name to its dtype and shape.
        for param_name, (param_dtype, param_shape) in metadata.items():
            assert isinstance(param_name, str)
            assert isinstance(param_dtype, torch.dtype)
            assert isinstance(param_shape, (tuple, torch.Size))

    def test_update_weights_requires_init(self, source_model):
        """update_weights raises when the transport was never initialized."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        sender = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        ).create_sender()
        sender.register_model(source_model)

        with pytest.raises(RuntimeError, match="Transport not initialized"):
            sender.update_weights()

    def test_update_weights_requires_model(self):
        """update_weights raises when no model has been registered."""
        from torchrl.weight_update.llm import SGLangWeightSyncScheme

        sender = SGLangWeightSyncScheme(
            server_url="http://localhost:30000",
            num_gpus=1,
        ).create_sender()
        # Intentionally skip register_model; fake an initialized transport so
        # the missing-model check is the one that fires.
        sender._transport = object()

        with pytest.raises(RuntimeError, match="No model registered"):
            sender.update_weights()
209+
210+
211+
@pytest.mark.gpu
@pytest.mark.skipif(not _has_sglang, reason="sglang not available")
@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSGLangCollectiveTransport:
    """Tests for SGLangCollectiveTransport construction and pre-init errors."""

    @staticmethod
    def _build_transport(**overrides):
        """Build a transport from shared defaults, with per-test overrides.

        Not collected by pytest (name does not start with ``test_``).
        """
        from torchrl.weight_update.llm import SGLangCollectiveTransport

        params = {
            "server_url": "http://localhost:30000",
            "master_address": "localhost",
            "master_port": 29500,
            "rank": 0,
            "world_size": 2,
        }
        params.update(overrides)
        return SGLangCollectiveTransport(**params)

    def test_transport_initialization(self):
        """The transport stores its constructor parameters verbatim."""
        transport = self._build_transport(device=0)

        assert transport.server_url == "http://localhost:30000"
        assert transport.master_address == "localhost"
        assert transport.master_port == 29500
        assert transport.rank == 0
        assert transport.world_size == 2
        assert transport.device == 0

    def test_transport_device_parsing(self):
        """Device may be given as a string, a torch.device, or None."""
        # "cuda:N" strings are reduced to the ordinal N.
        assert self._build_transport(device="cuda:1").device == 1
        # torch.device objects are handled the same way.
        assert self._build_transport(device=torch.device("cuda:2")).device == 2
        # None falls back to device 0.
        assert self._build_transport(device=None).device == 0

    def test_check_connection_before_init(self):
        """check_connection reports False before the group is initialized."""
        assert self._build_transport().check_connection() is False

    def test_send_weights_requires_init(self):
        """send_weights raises until the communication group exists."""
        transport = self._build_transport()

        with pytest.raises(RuntimeError, match="Communication group not initialized"):
            transport.send_weights("model", {"param": torch.zeros(10)})

    def test_init_requires_rank_zero(self):
        """Only the rank-0 (trainer) transport may initialize the group."""
        transport = self._build_transport(rank=1)

        with pytest.raises(RuntimeError, match="Only rank 0"):
            transport.init_all_workers_group({})
320+
321+
322+
if __name__ == "__main__":
    # Forward any unrecognized CLI flags straight to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", "-v", "-s", *extra_args])

0 commit comments

Comments
 (0)