import ray
import zmq
import time
import uuid
import pytest
import torch
import tensordict
from threading import Thread
from unittest.mock import MagicMock
from pathlib import Path
import sys
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from tensordict import TensorDict

# Make the project root importable so `transfer_queue` resolves when the
# tests are run from this directory.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

try:
    from transfer_queue.data_system import TransferQueueStorageSimpleUnit
    from transfer_queue.utils.zmq_utils import ZMQServerInfo, ZMQRequestType, ZMQMessage
except ImportError:
    # Fall back to mocks so this module still imports in environments where
    # the project package is unavailable (tests will then be meaningless but
    # collection does not crash).
    TransferQueueStorageSimpleUnit = MagicMock()
    ZMQServerInfo = MagicMock()
    ZMQRequestType = MagicMock()
    ZMQMessage = MagicMock()

# Mock ZMQ utilities if not available in test environment
def create_zmq_socket(context, socket_type, identity=None):
    """Create a socket of *socket_type* on *context*.

    If *identity* is provided (non-empty bytes), it is set as the socket's
    ZMQ identity so ROUTER peers can address replies to it.
    """
    sock = context.socket(socket_type)
    if identity:
        sock.setsockopt(zmq.IDENTITY, identity)
    return sock

# Mock Controller to handle handshake and data updates
class MockController:
    """Stand-in for the real controller.

    Binds two ROUTER sockets on random loopback ports (handshake and
    data-status updates) and runs a daemon thread per socket that ACKs
    every incoming message.
    """

    def __init__(self, controller_id="controller_001"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for handshake
        self.handshake_socket = self.context.socket(zmq.ROUTER)
        self.handshake_port = self._bind_to_random_port(self.handshake_socket)

        # Socket for data status updates
        self.data_update_socket = self.context.socket(zmq.ROUTER)
        self.data_update_port = self._bind_to_random_port(self.data_update_socket)

        self.zmq_server_info = ZMQServerInfo.create(
            role="CONTROLLER",
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "handshake_socket": self.handshake_port,
                "data_status_update_socket": self.data_update_port
            }
        )

        self.running = True
        self.handshake_thread = Thread(target=self._handle_handshake, daemon=True)
        self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True)
        self.handshake_thread.start()
        self.data_update_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind *socket* to a free loopback port and return the port number."""
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_handshake(self):
        """Daemon loop: ACK every handshake request until stop() is called."""
        poller = zmq.Poller()
        poller.register(self.handshake_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.handshake_socket in socks:
                    identity, msg_bytes = self.handshake_socket.recv_multipart()
                    msg = ZMQMessage.deserialize(msg_bytes)

                    # Send handshake ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.HANDSHAKE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Handshake successful"}
                    )
                    self.handshake_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                # Best-effort mock: swallow errors; during shutdown the
                # socket may already be closed, which is expected.
                if self.running:
                    pass

    def _handle_data_updates(self):
        """Daemon loop: ACK every data-status update until stop() is called."""
        poller = zmq.Poller()
        poller.register(self.data_update_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_update_socket in socks:
                    identity, msg_bytes = self.data_update_socket.recv_multipart()
                    msg = ZMQMessage.deserialize(msg_bytes)

                    # Send data update ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Data update received"}
                    )
                    self.data_update_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                # See _handle_handshake: errors are ignored on purpose.
                if self.running:
                    pass

    def stop(self):
        """Signal the worker threads to exit, then close both sockets."""
        self.running = False
        time.sleep(0.1)  # Give threads time to stop
        self.handshake_socket.close()
        self.data_update_socket.close()

# Mock client to send PUT/GET requests
class MockClient:
    """DEALER-socket client that issues blocking PUT/GET requests.

    Each request serializes a ZMQMessage, sends it to the storage unit's
    put/get socket, and waits (up to 5s) for the deserialized reply.
    """

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        self.socket.connect(storage_put_get_address)

    def send_put(self, client_id, global_indexes, local_indexes, field_data):
        """Send a PUT_DATA request and return the deserialized response."""
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=f"mock_client_{client_id}",
            body={
                "global_indexes": global_indexes,
                "local_indexes": local_indexes,
                "field_data": field_data
            }
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_get(self, client_id, local_indexes, fields):
        """Send a GET_DATA request for *fields* at *local_indexes*; return the response."""
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=f"mock_client_{client_id}",
            body={
                "local_indexes": local_indexes,
                "fields": fields
            }
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def close(self):
        """Close the socket and terminate the private ZMQ context."""
        self.socket.close()
        self.context.term()

@pytest.fixture(scope="session")
def ray_setup():
    """Session-scoped fixture: initialize Ray once, shut it down after all tests."""
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()

@pytest.fixture
def storage_setup(ray_setup):
    """Per-test fixture.

    Starts a MockController and a TransferQueueStorageSimpleUnit Ray actor,
    registers the controller with the actor, and yields
    ``(storage_actor, put_get_address, mock_controller)``.
    """
    storage_size = 10000
    tensordict.set_list_to_stack(True).set()

    # Start mock controller; unique id avoids clashes across tests.
    mock_controller = MockController(f"controller_{uuid.uuid4()}")
    time.sleep(0.5)  # Wait for controller sockets to be ready

    # Start Ray actor
    storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size)

    # Register controller info
    controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
    ray.get(storage_actor.register_controller_info.remote(controller_infos))

    # Get ZMQ address to connect client
    zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
    put_get_address = zmq_info.to_addr("put_get_socket")
    time.sleep(1)  # Wait for socket to be ready

    yield storage_actor, put_get_address, mock_controller

    # Cleanup
    mock_controller.stop()

def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and torch tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT data
    global_indexes = [0, 1, 2]
    local_indexes = [0, 1, 2]
    field_data = TensorDict({
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
        "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])]
    }, batch_size=[])

    response = client.send_put(0, global_indexes, local_indexes, field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data back for the first two indexes
    response = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = response.body["message"]["data"]
    assert "log_probs" in retrieved_data
    assert "rewards" in retrieved_data
    assert len(retrieved_data["log_probs"]) == 2
    assert len(retrieved_data["rewards"]) == 2

    # Verify data correctness
    torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([10.0]))
    torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0]))

    client.close()

def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients including overlapping local indexes."""
    _, put_get_address, _ = storage_setup

    num_clients = 5
    clients = [MockClient(put_get_address) for _ in range(num_clients)]

    # Each client puts unique data using disjoint local_indexes (stride 10).
    for i, client in enumerate(clients):
        global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        field_data = TensorDict({
            "log_probs": [torch.tensor([i, i + 1, i + 2]), torch.tensor([i + 3, i + 4, i + 5]),
                          torch.tensor([i + 6, i + 7, i + 8])],
            "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])]
        })

        response = client.send_put(i, global_indexes, local_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Now simulate another client that writes to an overlapping local index
    # (index 0, already owned by the first client) — last write should win.
    overlapping_client = MockClient(put_get_address)
    overlap_local_indexes = [0]  # Overlaps with first client's index 0
    overlap_field_data = TensorDict({
        "log_probs": [torch.tensor([999, 999, 999])],
        "rewards": [torch.tensor([999])]
    })
    response = overlapping_client.send_put(
        client_id=99,
        global_indexes=[0],
        local_indexes=overlap_local_indexes,
        field_data=overlap_field_data
    )
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client gets its own data (except for index 0 which was overwritten)
    for i, client in enumerate(clients):
        response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["message"]["data"]
        assert len(retrieved_data["log_probs"]) == 2
        assert len(retrieved_data["rewards"]) == 2

        # For index 0, expect data from overlapping_client; others from original client
        if i == 0:
            # Index 0 was overwritten
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
            # Index 1 remains original
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
        else:
            # All data remains original
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()

def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume and proper index handling."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT performance test
    put_latencies = []
    num_puts = 50
    batch_size = 128

    for i in range(num_puts):
        start = time.time()

        # Use larger batch size and more complex index mapping
        global_indexes = list(range(i * batch_size, (i + 1) * batch_size))
        local_indexes = list(range(i * batch_size, (i + 1) * batch_size))

        # Create larger tensor data to increase data volume
        log_probs_data = []
        rewards_data = []

        for j in range(batch_size):
            # Each sample contains larger tensors to increase data transfer volume
            log_probs_tensor = torch.randn(32768)
            rewards_tensor = torch.randn(32768)
            log_probs_data.append(log_probs_tensor)
            rewards_data.append(rewards_tensor)

        field_data = TensorDict(
            {
                "log_probs": log_probs_data,
                "rewards": rewards_data
            },
            batch_size=[batch_size]
        )

        response = client.send_put(0, global_indexes, local_indexes, field_data)
        latency = time.time() - start
        put_latencies.append(latency)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET performance test
    get_latencies = []
    num_gets = 50

    for i in range(num_gets):
        start = time.time()
        # Retrieve batch_size indices of data each time
        indices = list(range(i * batch_size, (i + 1) * batch_size))
        response = client.send_get(0, indices, ["log_probs", "rewards"])
        latency = time.time() - start
        get_latencies.append(latency)
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
    avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

    # Thresholds are generous to accommodate the larger data volume
    assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
    assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold"

    client.close()