Skip to content

Commit 84daa96

Browse files
authored
Merge pull request #8 from TransferQueue/dev
Dev
2 parents 4634c16 + 2953a56 commit 84daa96

File tree

2 files changed

+369
-0
lines changed

2 files changed

+369
-0
lines changed

tests/test_serial_utils_on_cpu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import sys
2+
from pathlib import Path
13
import pytest
24
import torch
35
import tensordict
46
import numpy as np
57
from tensordict import NonTensorData, NonTensorStack, TensorDict
8+
9+
# Make the repository root importable so the tests use the in-repo package
# even when transfer_queue is not installed as a distribution.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue.utils.serial_utils import MsgpackEncoder, MsgpackDecoder
713

814

tests/test_simple_storage_unit.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
import ray
2+
import zmq
3+
import time
4+
import uuid
5+
import pytest
6+
import torch
7+
import tensordict
8+
from threading import Thread
9+
from unittest.mock import MagicMock
10+
from pathlib import Path
11+
import sys
12+
import numpy as np
13+
from concurrent.futures import ThreadPoolExecutor, as_completed
14+
from tensordict import TensorDict
15+
16+
# Make the repository root importable so the tests use the in-repo package
# even when transfer_queue is not installed as a distribution.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))


try:
    from transfer_queue.data_system import TransferQueueStorageSimpleUnit
    from transfer_queue.utils.zmq_utils import ZMQServerInfo, ZMQRequestType, ZMQMessage
except ImportError:
    # Fall back to mocks so this module can at least be collected when the
    # transfer_queue package is unavailable in the test environment.
    TransferQueueStorageSimpleUnit = MagicMock()
    ZMQServerInfo = MagicMock()
    ZMQRequestType = MagicMock()
    ZMQMessage = MagicMock()
29+
30+
31+
# Mock ZMQ utilities if not available in test environment
32+
def create_zmq_socket(context, socket_type, identity=None):
    """Build a socket of *socket_type* on *context*; stamp IDENTITY when given."""
    created = context.socket(socket_type)
    if not identity:
        return created
    created.setsockopt(zmq.IDENTITY, identity)
    return created
37+
38+
39+
# Mock Controller to handle handshake and data updates
40+
class MockController:
    """In-process stand-in for the TransferQueue controller.

    Runs two daemon threads, each servicing a ROUTER socket:
    one ACKs storage-unit handshakes, the other ACKs data-status updates.
    Call :meth:`stop` to shut the loops down and release ZMQ resources.
    """

    def __init__(self, controller_id="controller_001"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for handshake
        self.handshake_socket = self.context.socket(zmq.ROUTER)
        self.handshake_port = self._bind_to_random_port(self.handshake_socket)

        # Socket for data status updates
        self.data_update_socket = self.context.socket(zmq.ROUTER)
        self.data_update_port = self._bind_to_random_port(self.data_update_socket)

        self.zmq_server_info = ZMQServerInfo.create(
            role="CONTROLLER",
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "handshake_socket": self.handshake_port,
                "data_status_update_socket": self.data_update_port
            }
        )

        self.running = True
        self.handshake_thread = Thread(target=self._handle_handshake, daemon=True)
        self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True)
        self.handshake_thread.start()
        self.data_update_thread.start()

    def _bind_to_random_port(self, socket):
        # Let ZMQ choose a free ephemeral port so concurrent tests don't collide.
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_handshake(self):
        """Background loop: ACK every handshake request until stopped."""
        poller = zmq.Poller()
        poller.register(self.handshake_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout so `running` is re-checked
                if self.handshake_socket in socks:
                    identity, msg_bytes = self.handshake_socket.recv_multipart()
                    # Deserialize to validate framing; the request content is unused.
                    ZMQMessage.deserialize(msg_bytes)

                    # Send handshake ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.HANDSHAKE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Handshake successful"}
                    )
                    self.handshake_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except zmq.ZMQError:
                # Socket closed/terminated during shutdown — exit the loop.
                break
            except Exception:
                if self.running:
                    # Best-effort mock: ignore malformed traffic and keep serving.
                    continue

    def _handle_data_updates(self):
        """Background loop: ACK every data-status update until stopped."""
        poller = zmq.Poller()
        poller.register(self.data_update_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout so `running` is re-checked
                if self.data_update_socket in socks:
                    identity, msg_bytes = self.data_update_socket.recv_multipart()
                    # Deserialize to validate framing; the request content is unused.
                    ZMQMessage.deserialize(msg_bytes)

                    # Send data update ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Data update received"}
                    )
                    self.data_update_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except zmq.ZMQError:
                # Socket closed/terminated during shutdown — exit the loop.
                break
            except Exception:
                if self.running:
                    # Best-effort mock: ignore malformed traffic and keep serving.
                    continue

    def stop(self):
        """Stop both loops and release all ZMQ resources.

        Joins the worker threads (instead of a fixed sleep) so the sockets are
        no longer in use when closed, then terminates the context to avoid
        leaking it across tests.
        """
        self.running = False
        self.handshake_thread.join(timeout=1.0)
        self.data_update_thread.join(timeout=1.0)
        # linger=0 drops any unsent ACKs so context.term() cannot block.
        self.handshake_socket.close(linger=0)
        self.data_update_socket.close(linger=0)
        self.context.term()
126+
127+
128+
# Mock client to send PUT/GET requests
129+
class MockClient:
    """Minimal DEALER-socket client for exercising the storage unit's
    PUT/GET endpoint synchronously (send, then block on the reply)."""

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        # Drop any unsent message on close; otherwise context.term() in
        # close() can block forever waiting for delivery.
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.connect(storage_put_get_address)

    def send_put(self, client_id, global_indexes, local_indexes, field_data):
        """Send a PUT_DATA request and return the deserialized reply."""
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=f"mock_client_{client_id}",
            body={
                "global_indexes": global_indexes,
                "local_indexes": local_indexes,
                "field_data": field_data
            }
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_get(self, client_id, local_indexes, fields):
        """Send a GET_DATA request and return the deserialized reply."""
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=f"mock_client_{client_id}",
            body={
                "local_indexes": local_indexes,
                "fields": fields
            }
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def close(self):
        """Close the socket and terminate the private context."""
        self.socket.close()
        self.context.term()
164+
165+
166+
@pytest.fixture(scope="session")
def ray_setup():
    """Session-wide Ray runtime.

    try/finally guarantees ray.shutdown() runs even if teardown is entered
    via an exception, so a failed session doesn't leave Ray running.
    """
    ray.init(ignore_reinit_error=True)
    try:
        yield
    finally:
        ray.shutdown()
171+
172+
173+
@pytest.fixture
def storage_setup(ray_setup):
    """Start a storage actor plus a mock controller.

    Yields (storage_actor, put_get_address, mock_controller). The controller
    is stopped in a finally block so its threads/sockets are released even if
    actor setup or the test body raises.
    """
    storage_size = 10000
    tensordict.set_list_to_stack(True).set()

    # Start mock controller
    mock_controller = MockController(f"controller_{uuid.uuid4()}")
    time.sleep(0.5)  # Wait for controller sockets to be ready

    try:
        # Start Ray actor
        storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size)

        # Register controller info
        controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
        ray.get(storage_actor.register_controller_info.remote(controller_infos))

        # Get ZMQ address to connect client
        zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
        put_get_address = zmq_info.to_addr("put_get_socket")
        time.sleep(1)  # Wait for socket to be ready

        yield storage_actor, put_get_address, mock_controller
    finally:
        # Cleanup runs even when setup or the test fails.
        mock_controller.stop()
198+
199+
200+
def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and torch tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)
    try:
        # PUT data
        global_indexes = [0, 1, 2]
        local_indexes = [0, 1, 2]
        field_data = TensorDict({
            "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])]
        }, batch_size=[])

        response = client.send_put(0, global_indexes, local_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # GET back the first two samples of both fields
        response = client.send_get(0, [0, 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["message"]["data"]
        assert "log_probs" in retrieved_data
        assert "rewards" in retrieved_data
        assert len(retrieved_data["log_probs"]) == 2
        assert len(retrieved_data["rewards"]) == 2

        # Verify data correctness
        torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
        torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
        torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([10.0]))
        torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0]))
    finally:
        # Close even on assertion failure so the ZMQ context doesn't leak.
        client.close()
234+
235+
236+
def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients including overlapping local indexes"""
    _, put_get_address, _ = storage_setup

    num_clients = 5
    clients = [MockClient(put_get_address) for _ in range(num_clients)]
    overlapping_client = None
    try:
        # Each client puts unique data using different local_indexes
        for i, client in enumerate(clients):
            global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
            local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
            field_data = TensorDict({
                "log_probs": [torch.tensor([i, i + 1, i + 2]), torch.tensor([i + 3, i + 4, i + 5]),
                              torch.tensor([i + 6, i + 7, i + 8])],
                "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])]
            })

            response = client.send_put(i, global_indexes, local_indexes, field_data)
            assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # Now simulate another client that writes to an overlapping local index (index 0)
        overlapping_client = MockClient(put_get_address)
        overlap_local_indexes = [0]  # Overlaps with first client's index 0
        overlap_field_data = TensorDict({
            "log_probs": [torch.tensor([999, 999, 999])],
            "rewards": [torch.tensor([999])]
        })
        response = overlapping_client.send_put(
            client_id=99,
            global_indexes=[0],
            local_indexes=overlap_local_indexes,
            field_data=overlap_field_data
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # Each original client gets its own data (except for index 0 which was overwritten)
        for i, client in enumerate(clients):
            response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
            assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

            retrieved_data = response.body["message"]["data"]
            assert len(retrieved_data["log_probs"]) == 2
            assert len(retrieved_data["rewards"]) == 2

            # For index 0, expect data from overlapping_client; others from original client
            if i == 0:
                # Index 0 was overwritten
                torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
                torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
                # Index 1 remains original
                torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
                torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
            else:
                # All data remains original
                torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
                torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
                torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
                torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))
    finally:
        # Close all sockets even on assertion failure so ZMQ contexts don't leak.
        for client in clients:
            client.close()
        if overlapping_client is not None:
            overlapping_client.close()
299+
300+
301+
def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume and proper index handling"""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)
    try:
        # PUT performance test
        put_latencies = []
        num_puts = 50
        batch_size = 128

        for i in range(num_puts):
            start = time.time()

            # Use larger batch size and more complex index mapping
            global_indexes = list(range(i * batch_size, (i + 1) * batch_size))
            local_indexes = list(range(i * batch_size, (i + 1) * batch_size))

            # Each sample carries large tensors to increase data transfer volume.
            log_probs_data = [torch.randn(32768) for _ in range(batch_size)]
            rewards_data = [torch.randn(32768) for _ in range(batch_size)]

            field_data = TensorDict(
                {
                    "log_probs": log_probs_data,
                    "rewards": rewards_data
                },
                batch_size=[batch_size]
            )

            response = client.send_put(0, global_indexes, local_indexes, field_data)
            latency = time.time() - start
            put_latencies.append(latency)
            assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # GET performance test
        get_latencies = []
        num_gets = 50

        for i in range(num_gets):
            start = time.time()
            # Retrieve batch_size indices of data each time
            indices = list(range(i * batch_size, (i + 1) * batch_size))
            response = client.send_get(0, indices, ["log_probs", "rewards"])
            latency = time.time() - start
            get_latencies.append(latency)
            assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
        avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

        # Thresholds are generous to accommodate the large data volume.
        assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
        assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold"
    finally:
        # Close even on assertion failure so the ZMQ context doesn't leak.
        client.close()

0 commit comments

Comments
 (0)