Commit b712be9

feat: add data parallel rank to KVEventBatch (vllm-project#18925)
1 parent a8da78e commit b712be9

6 files changed: +362 additions, −86 deletions


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -145,6 +145,7 @@ steps:
   - examples/offline_inference/rlhf_colocate.py
   - tests/examples/offline_inference/data_parallel.py
   - tests/v1/test_async_llm_dp.py
+  - tests/v1/engine/test_engine_core_client.py
   commands:
   # test with tp=2 and external_dp=2
   - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -154,6 +155,7 @@ steps:
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
+  - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py

tests/distributed/conftest.py

Lines changed: 66 additions & 41 deletions
@@ -13,11 +13,13 @@

 from .test_events import SampleBatch

+DP_RANK = 0
+

 @pytest.fixture
 def random_port():
     """Generate a random port number for testing"""
-    return random.randint(10000, 60000)
+    return random.randint(10000, 59900)


 @pytest.fixture
@@ -30,21 +32,23 @@ def publisher_config(random_port, request):
         replay_endpoint = endpoint + "-replay"
     else:
         endpoint = f"tcp://*:{random_port}"
-        replay_endpoint = f"tcp://*:{random_port + 1}"
+        replay_endpoint = f"tcp://*:{random_port + 100}"

-    return KVEventsConfig(enable_kv_cache_events=True,
-                          publisher="zmq",
-                          endpoint=endpoint,
-                          replay_endpoint=replay_endpoint,
-                          buffer_steps=100,
-                          hwm=1000,
-                          topic="test")
+    return KVEventsConfig(
+        enable_kv_cache_events=True,
+        publisher="zmq",
+        endpoint=endpoint,
+        replay_endpoint=replay_endpoint,
+        buffer_steps=100,
+        hwm=1000,
+        topic="test",
+    )


 @pytest.fixture
 def publisher(publisher_config):
     """Create and return a publisher instance"""
-    pub = EventPublisherFactory.create(publisher_config)
+    pub = EventPublisherFactory.create(publisher_config, DP_RANK)
     yield pub
     pub.shutdown()

@@ -60,34 +64,49 @@ def subscriber(publisher_config):
     if replay_endpoint and replay_endpoint.startswith("tcp://*"):
         replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")

-    sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
+    sub = MockSubscriber(
+        [endpoint],
+        [replay_endpoint] if replay_endpoint else None,
+        publisher_config.topic,
+    )
     yield sub
     sub.close()


 class MockSubscriber:
     """Helper class to receive and verify published events"""

-    def __init__(self,
-                 pub_endpoint: str,
-                 replay_endpoint: Optional[str] = None,
-                 topic: str = "",
-                 decode_type=SampleBatch):
+    def __init__(
+        self,
+        pub_endpoints: Union[str, list[str]],
+        replay_endpoints: Optional[Union[str, list[str]]] = None,
+        topic: str = "",
+        decode_type=SampleBatch,
+    ):
         self.ctx = zmq.Context.instance()

-        # Set up subscriber socket
-        self.sub = self.ctx.socket(zmq.SUB)
-        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
-        self.sub.connect(pub_endpoint)
+        # Convert single endpoint to list for consistency
+        if isinstance(pub_endpoints, str):
+            pub_endpoints = [pub_endpoints]
+        if isinstance(replay_endpoints, str):
+            replay_endpoints = [replay_endpoints]

-        # Set up replay socket if provided
-        self.replay = None
-        if replay_endpoint:
-            self.replay = self.ctx.socket(zmq.REQ)
-            self.replay.connect(replay_endpoint)
+        # Set up subscriber socket - connect to all endpoints
+        self.sub = self.ctx.socket(zmq.SUB)
+        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
+        for endpoint in pub_endpoints:
+            self.sub.connect(endpoint)
+
+        # Set up replay sockets if provided
+        self.replay_sockets = []
+        if replay_endpoints:
+            for replay_endpoint in replay_endpoints:
+                replay = self.ctx.socket(zmq.REQ)
+                replay.connect(replay_endpoint)
+                self.replay_sockets.append(replay)

         self.topic = topic
-        self.topic_bytes = topic.encode('utf-8')
+        self.topic_bytes = topic.encode("utf-8")
         self.received_msgs: list[tuple[int, SampleBatch]] = []
         self.last_seq = -1
         self.decoder = msgspec.msgpack.Decoder(type=decode_type)
@@ -107,25 +126,31 @@ def receive_one(self,
         self.received_msgs.append((seq, data))
         return seq, data

-    def request_replay(self, start_seq: int) -> None:
+    def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
         """Request replay of messages starting from start_seq"""
-        if not self.replay:
-            raise ValueError("Replay socket not initialized")
-
-        self.replay.send(start_seq.to_bytes(8, "big"))
-
-    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
-        """Receive replayed messages"""
-        if not self.replay:
-            raise ValueError("Replay socket not initialized")
-
+        if not self.replay_sockets:
+            raise ValueError("Replay sockets not initialized")
+        if socket_idx >= len(self.replay_sockets):
+            raise ValueError(f"Invalid socket index {socket_idx}")
+
+        self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
+
+    def receive_replay(self,
+                       socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
+        """Receive replayed messages from a specific replay socket"""
+        if not self.replay_sockets:
+            raise ValueError("Replay sockets not initialized")
+        if socket_idx >= len(self.replay_sockets):
+            raise ValueError(f"Invalid socket index {socket_idx}")
+
+        replay_socket = self.replay_sockets[socket_idx]
         replayed: list[tuple[int, SampleBatch]] = []
         while True:
             try:
-                if not self.replay.poll(1000):
+                if not replay_socket.poll(1000):
                     break

-                frames = self.replay.recv_multipart()
+                frames = replay_socket.recv_multipart()
                 if not frames or not frames[-1]:
                     # End of replay marker
                     break
@@ -142,5 +167,5 @@ def receive_replay(self) -> list[tuple[int, SampleBatch]]:
     def close(self):
         """Clean up resources"""
         self.sub.close()
-        if self.replay:
-            self.replay.close()
+        for replay in self.replay_sockets:
+            replay.close()
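For orientation, here is a minimal sketch of driving the reworked multi-endpoint MockSubscriber from a test, assuming two DP ranks already publishing on the illustrative addresses below. The endpoints and sequence numbers are made up; the method signatures come from the diff above.

from .conftest import MockSubscriber  # as imported in the tests

# One SUB socket connected to both ranks' pub endpoints,
# plus one REQ replay socket per rank.
sub = MockSubscriber(
    pub_endpoints=["tcp://127.0.0.1:5557", "tcp://127.0.0.1:5558"],
    replay_endpoints=["tcp://127.0.0.1:5657", "tcp://127.0.0.1:5658"],
    topic="test",
)

result = sub.receive_one(timeout=1000)  # batches from either rank arrive here
if result is not None:
    seq, batch = result

sub.request_replay(start_seq=0, socket_idx=1)  # ask rank 1 to resend from seq 0
replayed = sub.receive_replay(socket_idx=1)

sub.close()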

tests/distributed/test_events.py

Lines changed: 67 additions & 2 deletions
@@ -9,6 +9,8 @@
 from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
                                         NullEventPublisher)

+DP_RANK = 0
+

 class EventSample(
         msgspec.Struct,
@@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
     publisher_config.replay_endpoint = None

     publisher_config.topic = "foo"
-    pub = EventPublisherFactory.create(publisher_config)
+    pub = EventPublisherFactory.create(publisher_config, DP_RANK)

     from .conftest import MockSubscriber
     sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
@@ -185,9 +187,72 @@ def publish_events():

 def test_null_publisher():
     """Test that NullEventPublisher can be used without errors"""
-    publisher = NullEventPublisher()
+    publisher = NullEventPublisher(DP_RANK)

     # This should not raise any errors
     batch = create_test_events(5)
     publisher.publish(batch)
     publisher.shutdown()
+
+
+def test_data_parallel_rank_tagging(publisher_config):
+    """Test that events are properly tagged with their data parallel rank"""
+
+    publisher_config.topic = "foo"
+    pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK)
+    pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1)
+
+    # Hardcode the expected endpoints based on port offsetting behavior
+    # Both ranks get offsets according to _offset_endpoint_port function
+    base_endpoint = publisher_config.endpoint
+    if "tcp://" in base_endpoint:
+        # For TCP endpoints:
+        # tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
+        expected_endpoint_0 = base_endpoint  # rank 0 gets port + 0 = same port
+        expected_endpoint_1 = base_endpoint.replace(
+            ":5557", ":5558")  # rank 1 gets port + 1
+    else:
+        # For inproc endpoints:
+        # inproc://test -> inproc://test, inproc://test_dp1
+        expected_endpoint_0 = base_endpoint  # rank 0 gets base
+        expected_endpoint_1 = base_endpoint + "_dp1"  # rank 1 gets _dp1
+
+    from .conftest import MockSubscriber
+    sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
+    sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
+
+    try:
+        time.sleep(0.1)  # Let publishers start up
+
+        # Publish events from different ranks
+        batch_0 = create_test_events(2)
+        batch_1 = create_test_events(3)
+
+        pub_0.publish(batch_0)
+        pub_1.publish(batch_1)
+
+        # Receive events from rank 0
+        result_0 = sub_0.receive_one(timeout=200)
+        assert result_0 is not None, "No message received from rank 0"
+        seq_0, received_0 = result_0
+
+        # Receive events from rank 1
+        result_1 = sub_1.receive_one(timeout=200)
+        assert result_1 is not None, "No message received from rank 1"
+        seq_1, received_1 = result_1
+
+        # Verify DP rank tagging
+        assert received_0.data_parallel_rank == 0, (
+            f"Expected DP rank 0, got {received_0.data_parallel_rank}")
+        assert received_1.data_parallel_rank == 1, (
+            f"Expected DP rank 1, got {received_1.data_parallel_rank}")
+
+        # Verify event content is correct
+        assert len(received_0.events) == 2, "Wrong number of events from rank 0"
+        assert len(received_1.events) == 3, "Wrong number of events from rank 1"
+
+    finally:
+        pub_0.shutdown()
+        pub_1.shutdown()
+        sub_0.close()
+        sub_1.close()
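test_data_parallel_rank_tagging hardcodes the per-rank endpoints instead of calling the private helper its comments name; judging only from those comments and assertions, the _offset_endpoint_port rule amounts to roughly the following. This is a standalone mirror for illustration, not the vllm implementation.

def offset_endpoint_port(endpoint: str, dp_rank: int) -> str:
    """Illustrative mirror of the per-rank endpoint derivation.

    TCP endpoints shift their port by the DP rank; other schemes such as
    inproc get a "_dp{rank}" suffix. Rank 0 keeps the base endpoint.
    """
    if dp_rank == 0:
        return endpoint
    if endpoint.startswith("tcp://"):
        base, port = endpoint.rsplit(":", 1)
        return f"{base}:{int(port) + dp_rank}"
    return f"{endpoint}_dp{dp_rank}"


assert offset_endpoint_port("tcp://localhost:5557", 1) == "tcp://localhost:5558"
assert offset_endpoint_port("inproc://test", 1) == "inproc://test_dp1"
assert offset_endpoint_port("inproc://test", 0) == "inproc://test"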
