Skip to content

Commit b36d10e

Browse files
authored
add test for async-sync function invocation handler (#553)
* Refactor logging level and enhance async context handling - Changed logging level from info to debug in RegularRequestHandler for improved log granularity. - Refactored `_async_inject_context` in base.py to streamline context injection for async functions. - Introduced `_handle_async_function` to manage the invocation of async and sync functions, improving code clarity. - Added tests for `_handle_async_function` and `_async_inject_context` to ensure correct behavior and context handling. * Add integration test for asynchronous API functionality - Introduced a new test file `test_async.py` to validate the behavior of an asynchronous API using the LitServe framework. - Implemented a minimal async API class and a corresponding test case to ensure correct prediction output for async requests. - Utilized pytest and httpx for testing the async functionality, ensuring proper handling of requests and responses.
1 parent 550a23e commit b36d10e

File tree

4 files changed

+67
-10
lines changed

4 files changed

+67
-10
lines changed

src/litserve/loops/base.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,7 @@ def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
4545
return func(*args, **kwargs)
4646

4747

48-
async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
49-
sig = inspect.signature(func)
50-
51-
# Determine if we need to inject context
52-
if "context" in sig.parameters:
53-
kwargs["context"] = context
54-
48+
async def _handle_async_function(func, *args, **kwargs):
5549
# Call the function based on its type
5650
if inspect.isasyncgenfunction(func):
5751
# Async generator - return directly (don't await)
@@ -69,6 +63,16 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, *
6963
return result
7064

7165

66+
async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
67+
sig = inspect.signature(func)
68+
69+
# Determine if we need to inject context
70+
if "context" in sig.parameters:
71+
kwargs["context"] = context
72+
73+
return await _handle_async_function(func, *args, **kwargs)
74+
75+
7276
def collate_requests(
7377
lit_api: LitAPI,
7478
request_queue: Queue,

src/litserve/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ async def handle_request(self, request, request_type) -> Response:
300300
class RegularRequestHandler(BaseRequestHandler):
301301
async def handle_request(self, request, request_type) -> Response:
302302
try:
303-
logger.info(f"Handling request: {request}")
303+
logger.debug(f"Handling request: {request}")
304304
# Prepare request
305305
payload = await self._prepare_request(request, request_type)
306306

tests/integration/test_async.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
from asgi_lifespan import LifespanManager
3+
from httpx import ASGITransport, AsyncClient
4+
5+
import litserve as ls
6+
from litserve.utils import wrap_litserve_start
7+
8+
9+
class MinimalAsyncAPI(ls.LitAPI):
10+
def setup(self, device):
11+
self.model = None
12+
13+
async def predict(self, x):
14+
y = x["input"] ** 2
15+
return {"output": y}
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_async_api():
20+
server = ls.LitServer(MinimalAsyncAPI(enable_async=True))
21+
with wrap_litserve_start(server) as server:
22+
async with LifespanManager(server.app) as manager, AsyncClient(
23+
transport=ASGITransport(app=manager.app), base_url="http://test"
24+
) as ac:
25+
response = await ac.post("/predict", json={"input": 2})
26+
assert response.json() == {"output": 4}

tests/test_loops.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import threading
2121
import time
2222
from queue import Empty, Queue
23-
from typing import Dict, List, Optional
23+
from typing import AsyncGenerator, Dict, List, Optional
2424
from unittest.mock import MagicMock, patch
2525

2626
import pytest
@@ -33,7 +33,7 @@
3333
from litserve import LitAPI
3434
from litserve.callbacks import CallbackRunner
3535
from litserve.loops import BatchedStreamingLoop, LitLoop, Output, StreamingLoop, inference_worker
36-
from litserve.loops.base import DefaultLoop
36+
from litserve.loops.base import DefaultLoop, _async_inject_context, _handle_async_function
3737
from litserve.loops.continuous_batching_loop import (
3838
ContinuousBatchingLoop,
3939
notify_timed_out_requests,
@@ -918,3 +918,30 @@ async def test_continuous_batching_run(continuous_batching_setup):
918918
assert o == ""
919919
assert status == LitAPIStatus.FINISH_STREAMING
920920
assert response_type == LoopResponseType.STREAMING
921+
922+
923+
@pytest.mark.asyncio
924+
async def test_handle_async_function():
925+
async def async_func():
926+
return "async"
927+
928+
def sync_func():
929+
return "sync"
930+
931+
async def async_gen():
932+
for i in range(3):
933+
yield i
934+
935+
assert await _handle_async_function(async_func) == "async"
936+
assert await _handle_async_function(sync_func) == "sync"
937+
async_gen = await _handle_async_function(async_gen)
938+
assert isinstance(async_gen, AsyncGenerator)
939+
940+
941+
@pytest.mark.asyncio
942+
async def test_async_inject_context():
943+
async def async_func(x, context=0):
944+
return x * context["a"]
945+
946+
context = {"a": 1}
947+
assert await _async_inject_context(context, async_func, 2) == 2

0 commit comments

Comments
 (0)