55import copy
66import itertools
77import logging
8+ import random
89import threading
10+ import time
911
1012import pytest
1113import pytest_asyncio
@@ -40,7 +42,7 @@ class Engine(cognitive_engine.Engine, threading.Thread):
4042
4143 def __init__ (self , engine_id , zeromq_address , handle_method = None ):
4244 """Initialize the engine and engine runner."""
43- super ().__init__ ()
45+ super ().__init__ (daemon = True )
4446 self .engine_id = engine_id
4547 self .engine_name = f"Engine-{ engine_id } "
4648 self .zeromq_address = zeromq_address
@@ -73,6 +75,7 @@ def handle(self, input_frame):
7375
7476 def run (self ):
7577 """Run the engine runner."""
78+ logger .info (f"Running engine { self .engine_id } in a new thread" )
7679 self .engine_runner .run ()
7780
7881 async def run_async (self ):
@@ -157,6 +160,12 @@ def engine_disconnection_timeout():
157160 return 5
158161
159162
163+ @pytest .fixture
164+ def num_tokens ():
165+ """Number of tokens to use for the input producer."""
166+ return DEFAULT_NUM_TOKENS
167+
168+
160169@pytest_asyncio .fixture
161170async def run_server (
162171 server_frontend_port ,
@@ -207,12 +216,24 @@ def engine_ids():
207216 return None
208217
209218
219+ @pytest .fixture
220+ def run_engines_threaded ():
221+ """Run engines in a different thread."""
222+ return False
223+
224+
210225@pytest_asyncio .fixture
211226async def run_engines (
212- run_server , server_backend_port , num_engines , handle_method , engine_ids
227+ run_server ,
228+ server_backend_port ,
229+ num_engines ,
230+ handle_method ,
231+ engine_ids ,
232+ run_engines_threaded ,
213233):
214234 """Run engines connected to the server backend port."""
215235 engines = []
236+ engine_tasks = []
216237 logger .info (f"Running engines, connecting to { server_backend_port = } !" )
217238
218239 for i in range (num_engines ):
@@ -221,17 +242,24 @@ async def run_engines(
221242 engine = Engine (engine_ids [i ], zeromq_address , handle_method )
222243 else :
223244 engine = Engine (i , zeromq_address , handle_method )
224- task = asyncio .create_task (engine .run_async ())
225- engines .append (task )
226- task .add_done_callback (
227- lambda t : t .result () if not t .cancelled () else None
228- )
245+ engines .append (engine )
246+ if run_engines_threaded :
247+ engine .start ()
248+ # task = asyncio.create_task(asyncio.to_thread(engine.join))
249+ else :
250+ task = asyncio .create_task (engine .run_async ())
251+ task .add_done_callback (
252+ lambda t : t .result () if not t .cancelled () else None
253+ )
254+ engine_tasks .append (task )
229255
230256 yield engines
257+ if run_engines_threaded :
258+ return
231259 logger .info ("Tearing down engines" )
232- for task in engines :
260+ for task in engine_tasks :
233261 task .cancel ()
234- await asyncio .gather (* engines , return_exceptions = True )
262+ await asyncio .gather (* engine_tasks , return_exceptions = True )
235263 logger .info ("Done tearing down engines" )
236264
237265
@@ -277,6 +305,43 @@ async def producer() -> gabriel_pb2.InputFrame | None:
277305 input_producer .stop ()
278306
279307
308+ @pytest .fixture
309+ def multiple_input_producers (target_engines , num_inputs_to_send ):
310+ """Create an InputProducer that sends text frames to the server."""
311+ logger .info (f"Target engines: { target_engines } " )
312+
313+ inputs_sent = 0
314+
315+ async def producer () -> gabriel_pb2 .InputFrame | None :
316+ logger .info ("Producing input" )
317+ frame = gabriel_pb2 .InputFrame ()
318+ frame .payload_type = gabriel_pb2 .PayloadType .TEXT
319+ frame .string_payload = "Hello from client"
320+ await asyncio .sleep (0.5 )
321+
322+ nonlocal inputs_sent
323+ nonlocal num_inputs_to_send
324+ inputs_sent += 1
325+ if num_inputs_to_send > 0 and inputs_sent > num_inputs_to_send :
326+ return None
327+ logger .info (f"Inputs sent: { inputs_sent } " )
328+
329+ return frame
330+
331+ producer1 = InputProducer (
332+ producer = producer , target_engine_ids = target_engines
333+ )
334+ producer2 = InputProducer (
335+ producer = producer , target_engine_ids = target_engines
336+ )
337+ producer3 = InputProducer (
338+ producer = producer , target_engine_ids = target_engines
339+ )
340+ yield [producer1 , producer2 , producer3 ]
341+ producer1 .stop ()
342+ producer2 .stop ()
343+
344+
280345@pytest .fixture
281346def empty_frame_producer (target_engines , num_inputs_to_send ):
282347 """A producer that does not set fields in the frame it returns."""
@@ -1351,3 +1416,64 @@ async def test_zeromq_result_output(
13511416 assert result .target_engine_id == "Engine-0"
13521417 assert result .string_result == "hello"
13531418 assert result .frame_id == 1
1419+
1420+
1421+ def heterogenous_engine_handle (input_frame ):
1422+ """A handle method that sleeps different durations."""
1423+ sleep_duration = random .choice ([0.01 , 0.02 , 0.03 ])
1424+ time .sleep (sleep_duration )
1425+ logger .info (f"Slept for { sleep_duration } seconds" )
1426+ status = gabriel_pb2 .Status ()
1427+ status .code = gabriel_pb2 .StatusCode .SUCCESS
1428+
1429+ return cognitive_engine .Result (status , "hello" )
1430+
1431+
1432+ @pytest .mark .parametrize ("num_engines" , [3 ])
1433+ @pytest .mark .parametrize (
1434+ "target_engines" , [["Engine-0" , "Engine-1" , "Engine-2" ]]
1435+ )
1436+ @pytest .mark .parametrize ("run_engines_threaded" , [True ])
1437+ @pytest .mark .parametrize ("handle_method" , [heterogenous_engine_handle ])
1438+ @pytest .mark .asyncio
1439+ async def test_tokens_bug (
1440+ multiple_input_producers ,
1441+ server_frontend_port ,
1442+ target_engines ,
1443+ run_engines ,
1444+ response_state ,
1445+ prometheus_client_port ,
1446+ ):
1447+ """Test that we never exceed the token semaphore limit."""
1448+ response_state .clear ()
1449+ client1 = ZeroMQClient (
1450+ f"tcp://{ DEFAULT_SERVER_HOST } :{ server_frontend_port } " ,
1451+ multiple_input_producers ,
1452+ get_multiple_engine_consumer (response_state ),
1453+ prometheus_client_port ,
1454+ )
1455+ task1 = asyncio .create_task (client1 .launch_async ())
1456+
1457+ client2 = ZeroMQClient (
1458+ f"tcp://{ DEFAULT_SERVER_HOST } :{ server_frontend_port } " ,
1459+ multiple_input_producers ,
1460+ get_multiple_engine_consumer (response_state ),
1461+ prometheus_client_port ,
1462+ )
1463+ task2 = asyncio .create_task (client2 .launch_async ())
1464+
1465+ await asyncio .sleep (10 )
1466+
1467+ task1 .cancel ()
1468+ task2 .cancel ()
1469+ try :
1470+ logger .info ("Waiting for client tasks to cancel" )
1471+ await task1
1472+ await task2
1473+ except asyncio .CancelledError :
1474+ task = asyncio .current_task ()
1475+ if task is not None and task .cancelled ():
1476+ raise
1477+ logger .info ("Client tasks are cancelled" )
1478+
1479+ assert len (response_state ) == len (target_engines )
0 commit comments