Skip to content

Commit fb92bf2

Browse files
committed
Fix semaphore token bug and add test
1 parent 6a95842 commit fb92bf2

File tree

4 files changed

+150
-18
lines changed

4 files changed

+150
-18
lines changed

python-client/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "gabriel-client"
7-
version = "4.1.4"
7+
version = "4.1.6"
88
description = "Client library for the Gabriel real-time AI orchestration framework"
99
requires-python = ">=3.10"
1010

server/src/gabriel_server/network_engine/engine_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import asyncio
77
import logging
8+
import threading
89

910
import zmq
1011
import zmq.asyncio
@@ -57,7 +58,7 @@ def __init__(
5758
self.all_responses_required = all_responses_required
5859
self.timeout = timeout
5960
self.request_retries = request_retries
60-
self.running = True
61+
self.stop_event = threading.Event()
6162
self.done_event = asyncio.Event()
6263

6364
def run(self):
@@ -69,7 +70,7 @@ async def run_async(self):
6970
context = zmq.asyncio.Context()
7071

7172
try:
72-
while self.running and self.request_retries > 0:
73+
while not self.stop_event.is_set() and self.request_retries > 0:
7374
socket = context.socket(zmq.REQ)
7475
try:
7576
socket.setsockopt(zmq.LINGER, 0)
@@ -103,7 +104,7 @@ async def engine_loop(self, socket):
103104
f"{self.server_address}"
104105
)
105106

106-
while self.running:
107+
while not self.stop_event.is_set():
107108
if await socket.poll(self.timeout) == 0:
108109
logger.warning(f"{self.engine_id}: no response from server")
109110
self.request_retries -= 1
@@ -196,8 +197,7 @@ async def engine_loop(self, socket):
196197

197198
async def stop(self):
198199
"""Stops the engine runner."""
199-
self.running = False
200-
await self.done_event.wait()
200+
self.stop_event.set()
201201

202202

203203
def create_engine_result_payload(result_proto):

server/src/gabriel_server/network_engine/server_runner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ async def _receive_from_engine_worker_helper(self):
284284

285285
# Check if the result corresponds to the latest input that was
286286
# available for this engine from this producer
287+
287288
latest_input = producer_info.latest_input_sent_to_engine
289+
288290
# Check if this engine is the first to finish processing the latest
289291
# input. If so, it should get the next input from the queue.
290292
if (
@@ -540,10 +542,14 @@ async def send_payload(self, metadata_payload):
540542

541543
async def send_next_input(self):
542544
"""Send next input from queue."""
545+
current_producer = self._producers[0]
546+
if len(self._producers) > 1:
547+
current_producer.latest_input_sent_to_engine = None
548+
543549
for _ in range(len(self._producers)):
544-
producer_info = self._producers.popleft()
545-
self._producers.append(producer_info)
546-
metadata_payload = await producer_info.get_input_from_queue(
550+
self._producers.rotate(-1)
551+
producer = self._producers[0]
552+
metadata_payload = await producer.get_input_from_queue(
547553
self._engine_id
548554
)
549555
if metadata_payload is not None:

tests/integration/basic_test.py

Lines changed: 135 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import copy
66
import itertools
77
import logging
8+
import random
89
import threading
10+
import time
911

1012
import pytest
1113
import pytest_asyncio
@@ -40,7 +42,7 @@ class Engine(cognitive_engine.Engine, threading.Thread):
4042

4143
def __init__(self, engine_id, zeromq_address, handle_method=None):
4244
"""Initialize the engine and engine runner."""
43-
super().__init__()
45+
super().__init__(daemon=True)
4446
self.engine_id = engine_id
4547
self.engine_name = f"Engine-{engine_id}"
4648
self.zeromq_address = zeromq_address
@@ -73,6 +75,7 @@ def handle(self, input_frame):
7375

7476
def run(self):
    """Thread entry point: log the start, then run the engine runner.

    Delegates to ``self.engine_runner.run()`` and returns when that call
    returns.
    """
    logger.info(f"Running engine {self.engine_id} in a new thread")
    self.engine_runner.run()
7780

7881
async def run_async(self):
@@ -157,6 +160,12 @@ def engine_disconnection_timeout():
157160
return 5
158161

159162

163+
@pytest.fixture
def num_tokens():
    """Provide the token count for the input producer.

    Returns DEFAULT_NUM_TOKENS.
    """
    return DEFAULT_NUM_TOKENS
167+
168+
160169
@pytest_asyncio.fixture
161170
async def run_server(
162171
server_frontend_port,
@@ -207,12 +216,24 @@ def engine_ids():
207216
return None
208217

209218

219+
@pytest.fixture
def run_engines_threaded():
    """Flag selecting whether engines run in a separate thread.

    Returns False by default.
    """
    return False
223+
224+
210225
@pytest_asyncio.fixture
211226
async def run_engines(
212-
run_server, server_backend_port, num_engines, handle_method, engine_ids
227+
run_server,
228+
server_backend_port,
229+
num_engines,
230+
handle_method,
231+
engine_ids,
232+
run_engines_threaded,
213233
):
214234
"""Run engines connected to the server backend port."""
215235
engines = []
236+
engine_tasks = []
216237
logger.info(f"Running engines, connecting to {server_backend_port=}!")
217238

218239
for i in range(num_engines):
@@ -221,17 +242,24 @@ async def run_engines(
221242
engine = Engine(engine_ids[i], zeromq_address, handle_method)
222243
else:
223244
engine = Engine(i, zeromq_address, handle_method)
224-
task = asyncio.create_task(engine.run_async())
225-
engines.append(task)
226-
task.add_done_callback(
227-
lambda t: t.result() if not t.cancelled() else None
228-
)
245+
engines.append(engine)
246+
if run_engines_threaded:
247+
engine.start()
248+
# task = asyncio.create_task(asyncio.to_thread(engine.join))
249+
else:
250+
task = asyncio.create_task(engine.run_async())
251+
task.add_done_callback(
252+
lambda t: t.result() if not t.cancelled() else None
253+
)
254+
engine_tasks.append(task)
229255

230256
yield engines
257+
if run_engines_threaded:
258+
return
231259
logger.info("Tearing down engines")
232-
for task in engines:
260+
for task in engine_tasks:
233261
task.cancel()
234-
await asyncio.gather(*engines, return_exceptions=True)
262+
await asyncio.gather(*engine_tasks, return_exceptions=True)
235263
logger.info("Done tearing down engines")
236264

237265

@@ -277,6 +305,43 @@ async def producer() -> gabriel_pb2.InputFrame | None:
277305
input_producer.stop()
278306

279307

308+
@pytest.fixture
def multiple_input_producers(target_engines, num_inputs_to_send):
    """Create three InputProducers that send text frames to the server.

    All three producers wrap the same ``producer`` coroutine and therefore
    share one ``inputs_sent`` counter: ``num_inputs_to_send`` caps the
    combined number of frames produced, not the number per producer. A
    value of zero or less disables the cap.

    Args:
        target_engines: Engine ids passed to every producer as
            ``target_engine_ids``.
        num_inputs_to_send: Combined frame limit described above.

    Yields:
        A list with the three InputProducer instances. Every one of them
        is stopped on teardown.
    """
    logger.info(f"Target engines: {target_engines}")

    inputs_sent = 0

    async def producer() -> gabriel_pb2.InputFrame | None:
        logger.info("Producing input")
        frame = gabriel_pb2.InputFrame()
        frame.payload_type = gabriel_pb2.PayloadType.TEXT
        frame.string_payload = "Hello from client"
        # Throttle production to one frame per half second per call.
        await asyncio.sleep(0.5)

        nonlocal inputs_sent
        inputs_sent += 1
        # The counter is bumped before the check, so once the cap is
        # exceeded every further call returns None.
        if num_inputs_to_send > 0 and inputs_sent > num_inputs_to_send:
            return None
        logger.info(f"Inputs sent: {inputs_sent}")

        return frame

    producers = [
        InputProducer(producer=producer, target_engine_ids=target_engines)
        for _ in range(3)
    ]
    yield producers
    # Stop every producer that was created; the third one used to be
    # left running after the test finished.
    for input_producer in producers:
        input_producer.stop()
343+
344+
280345
@pytest.fixture
281346
def empty_frame_producer(target_engines, num_inputs_to_send):
282347
"""A producer that does not set fields in the frame it returns."""
@@ -1351,3 +1416,64 @@ async def test_zeromq_result_output(
13511416
assert result.target_engine_id == "Engine-0"
13521417
assert result.string_result == "hello"
13531418
assert result.frame_id == 1
1419+
1420+
1421+
def heterogenous_engine_handle(input_frame):
    """Simulate an engine whose processing time varies from call to call.

    Blocks the calling thread for a randomly chosen 0.01, 0.02 or 0.03
    seconds, then returns a Result built from a SUCCESS status and the
    payload "hello". ``input_frame`` is not inspected.
    """
    delay = random.choice([0.01, 0.02, 0.03])
    time.sleep(delay)
    logger.info(f"Slept for {delay} seconds")

    ok_status = gabriel_pb2.Status()
    ok_status.code = gabriel_pb2.StatusCode.SUCCESS
    return cognitive_engine.Result(ok_status, "hello")
1430+
1431+
1432+
@pytest.mark.parametrize("num_engines", [3])
@pytest.mark.parametrize(
    "target_engines", [["Engine-0", "Engine-1", "Engine-2"]]
)
@pytest.mark.parametrize("run_engines_threaded", [True])
@pytest.mark.parametrize("handle_method", [heterogenous_engine_handle])
@pytest.mark.asyncio
async def test_tokens_bug(
    multiple_input_producers,
    server_frontend_port,
    target_engines,
    run_engines,
    response_state,
    prometheus_client_port,
):
    """Test that we never exceed the token semaphore limit.

    Launches two clients against the server frontend, both built from the
    same list of input producers, while three threaded engines process
    frames with varying delays. After ten seconds both clients are
    cancelled and ``response_state`` is expected to hold one entry per
    target engine.

    NOTE(review): the only assertion here is on the size of
    ``response_state``; presumably a token overrun surfaces as an error
    elsewhere during the run -- confirm.
    """
    response_state.clear()
    server_address = f"tcp://{DEFAULT_SERVER_HOST}:{server_frontend_port}"

    client_tasks = []
    for _ in range(2):
        client = ZeroMQClient(
            server_address,
            multiple_input_producers,
            get_multiple_engine_consumer(response_state),
            prometheus_client_port,
        )
        client_tasks.append(asyncio.create_task(client.launch_async()))

    await asyncio.sleep(10)

    for client_task in client_tasks:
        client_task.cancel()

    logger.info("Waiting for client tasks to cancel")
    # Await each task under its own handler: with a single shared try
    # block, the CancelledError from the first task skipped the await of
    # the second, so it could still be unwinding at assertion time.
    for client_task in client_tasks:
        try:
            await client_task
        except asyncio.CancelledError:
            # Re-raise only if the test task itself is being cancelled.
            # NOTE(review): Task.cancelled() is true only once a task has
            # finished, so this guard likely never fires for the running
            # test task; Task.cancelling() (3.11+) may be what was
            # intended -- confirm.
            current = asyncio.current_task()
            if current is not None and current.cancelled():
                raise
    logger.info("Client tasks are cancelled")

    assert len(response_state) == len(target_engines)

0 commit comments

Comments
 (0)