1313
1414from .test_events import SampleBatch
1515
# Rank argument handed to EventPublisherFactory.create by the publisher fixture.
# NOTE(review): presumably the data-parallel rank -- confirm against the factory.
DP_RANK = 0
17+
1618
@pytest.fixture
def random_port():
    """Generate a random port number for testing.

    Returns an integer drawn uniformly from 10000..59900 (inclusive).
    publisher_config derives the replay port as ``random_port + 100``, so
    with this upper bound the derived port is at most 60000.
    """
    return random.randint(10000, 59900)
2123
2224
2325@pytest .fixture
@@ -30,21 +32,23 @@ def publisher_config(random_port, request):
3032 replay_endpoint = endpoint + "-replay"
3133 else :
3234 endpoint = f"tcp://*:{ random_port } "
33- replay_endpoint = f"tcp://*:{ random_port + 1 } "
35+ replay_endpoint = f"tcp://*:{ random_port + 100 } "
3436
35- return KVEventsConfig (enable_kv_cache_events = True ,
36- publisher = "zmq" ,
37- endpoint = endpoint ,
38- replay_endpoint = replay_endpoint ,
39- buffer_steps = 100 ,
40- hwm = 1000 ,
41- topic = "test" )
37+ return KVEventsConfig (
38+ enable_kv_cache_events = True ,
39+ publisher = "zmq" ,
40+ endpoint = endpoint ,
41+ replay_endpoint = replay_endpoint ,
42+ buffer_steps = 100 ,
43+ hwm = 1000 ,
44+ topic = "test" ,
45+ )
4246
4347
@pytest.fixture
def publisher(publisher_config):
    """Create a publisher instance for a test and shut it down afterwards.

    The publisher is built by EventPublisherFactory.create from the
    publisher_config fixture and the module-level DP_RANK, yielded to the
    test, and then shut down via its shutdown() method during teardown.
    """
    pub = EventPublisherFactory.create(publisher_config, DP_RANK)
    yield pub
    # Teardown: runs after the test that requested this fixture finishes.
    pub.shutdown()
5054
@@ -60,34 +64,49 @@ def subscriber(publisher_config):
6064 if replay_endpoint and replay_endpoint .startswith ("tcp://*" ):
6165 replay_endpoint = replay_endpoint .replace ("*" , "127.0.0.1" )
6266
63- sub = MockSubscriber (endpoint , replay_endpoint , publisher_config .topic )
67+ sub = MockSubscriber (
68+ [endpoint ],
69+ [replay_endpoint ] if replay_endpoint else None ,
70+ publisher_config .topic ,
71+ )
6472 yield sub
6573 sub .close ()
6674
6775
6876class MockSubscriber :
6977 """Helper class to receive and verify published events"""
7078
71- def __init__ (self ,
72- pub_endpoint : str ,
73- replay_endpoint : Optional [str ] = None ,
74- topic : str = "" ,
75- decode_type = SampleBatch ):
79+ def __init__ (
80+ self ,
81+ pub_endpoints : Union [str , list [str ]],
82+ replay_endpoints : Optional [Union [str , list [str ]]] = None ,
83+ topic : str = "" ,
84+ decode_type = SampleBatch ,
85+ ):
7686 self .ctx = zmq .Context .instance ()
7787
78- # Set up subscriber socket
79- self .sub = self .ctx .socket (zmq .SUB )
80- self .sub .setsockopt (zmq .SUBSCRIBE , topic .encode ('utf-8' ))
81- self .sub .connect (pub_endpoint )
88+ # Convert single endpoint to list for consistency
89+ if isinstance (pub_endpoints , str ):
90+ pub_endpoints = [pub_endpoints ]
91+ if isinstance (replay_endpoints , str ):
92+ replay_endpoints = [replay_endpoints ]
8293
83- # Set up replay socket if provided
84- self .replay = None
85- if replay_endpoint :
86- self .replay = self .ctx .socket (zmq .REQ )
87- self .replay .connect (replay_endpoint )
94+ # Set up subscriber socket - connect to all endpoints
95+ self .sub = self .ctx .socket (zmq .SUB )
96+ self .sub .setsockopt (zmq .SUBSCRIBE , topic .encode ("utf-8" ))
97+ for endpoint in pub_endpoints :
98+ self .sub .connect (endpoint )
99+
100+ # Set up replay sockets if provided
101+ self .replay_sockets = []
102+ if replay_endpoints :
103+ for replay_endpoint in replay_endpoints :
104+ replay = self .ctx .socket (zmq .REQ )
105+ replay .connect (replay_endpoint )
106+ self .replay_sockets .append (replay )
88107
89108 self .topic = topic
90- self .topic_bytes = topic .encode (' utf-8' )
109+ self .topic_bytes = topic .encode (" utf-8" )
91110 self .received_msgs : list [tuple [int , SampleBatch ]] = []
92111 self .last_seq = - 1
93112 self .decoder = msgspec .msgpack .Decoder (type = decode_type )
@@ -107,25 +126,31 @@ def receive_one(self,
107126 self .received_msgs .append ((seq , data ))
108127 return seq , data
109128
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
    """Request replay of messages starting from start_seq.

    Sends ``start_seq`` encoded as an 8-byte big-endian integer on the
    replay socket selected by ``socket_idx``. Nothing is read back here.

    Args:
        start_seq: Sequence number from which replay is requested.
        socket_idx: Position in ``self.replay_sockets`` of the socket to
            use. Defaults to the first socket.

    Raises:
        ValueError: If there are no replay sockets, or ``socket_idx`` is
            greater than or equal to the number of replay sockets.
    """
    if not self.replay_sockets:
        raise ValueError("Replay sockets not initialized")
    # Only the upper bound is validated; a negative index is resolved by
    # ordinary list indexing.
    if socket_idx >= len(self.replay_sockets):
        raise ValueError(f"Invalid socket index {socket_idx}")

    self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
137+
138+ def receive_replay (self ,
139+ socket_idx : int = 0 ) -> list [tuple [int , SampleBatch ]]:
140+ """Receive replayed messages from a specific replay socket"""
141+ if not self .replay_sockets :
142+ raise ValueError ("Replay sockets not initialized" )
143+ if socket_idx >= len (self .replay_sockets ):
144+ raise ValueError (f"Invalid socket index { socket_idx } " )
145+
146+ replay_socket = self .replay_sockets [socket_idx ]
122147 replayed : list [tuple [int , SampleBatch ]] = []
123148 while True :
124149 try :
125- if not self . replay .poll (1000 ):
150+ if not replay_socket .poll (1000 ):
126151 break
127152
128- frames = self . replay .recv_multipart ()
153+ frames = replay_socket .recv_multipart ()
129154 if not frames or not frames [- 1 ]:
130155 # End of replay marker
131156 break
@@ -142,5 +167,5 @@ def receive_replay(self) -> list[tuple[int, SampleBatch]]:
def close(self):
    """Clean up resources.

    Closes the subscriber socket first, then every socket held in
    ``self.replay_sockets`` in list order.
    """
    self.sub.close()
    for replay in self.replay_sockets:
        replay.close()
0 commit comments