Skip to content

Commit 2570834

Browse files
Fix on_request callback not triggering for API specs (#642)
* Fix on_request callback not triggering for API specs

  - Add server reference to LitSpec for callback access
  - Trigger on_request callback in OpenAIEmbeddingSpec and OpenAISpec endpoints
  - Add __getstate__ to exclude server from pickling
  - Add tests for callback triggering with specs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6fd40c1 commit 2570834

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

src/litserve/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,8 @@ def _register_spec_endpoints(self, lit_api: LitAPI):
10711071
specs = [lit_api.spec] if lit_api.spec else []
10721072
for spec in specs:
10731073
spec: LitSpec
1074+
# Set the server reference for callback triggering in spec endpoints
1075+
spec._server = self
10741076
# TODO check that path is not clashing
10751077
for path, endpoint, methods in spec.endpoints:
10761078
self.app.add_api_route(

src/litserve/specs/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def __init__(self):
3131
self.request_queue = None
3232
self.response_queue_id = None
3333

34+
def __getstate__(self):
35+
"""Exclude _server from pickling as it contains unpickleable objects."""
36+
state = self.__dict__.copy()
37+
state["_server"] = None
38+
return state
39+
3440
@property
3541
def stream(self):
3642
return False

src/litserve/specs/openai.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from fastapi.responses import StreamingResponse
2929
from pydantic import BaseModel, Field
3030

31+
from litserve.callbacks.base import EventTypes
3132
from litserve.constants import _DEFAULT_LIT_API_PATH
3233
from litserve.specs.base import LitSpec, _AsyncSpecWrapper
3334
from litserve.utils import LitAPIStatus, ResponseBufferItem, azip
@@ -502,6 +503,14 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks
502503
uids = [uuid.uuid4() for _ in range(request.n)]
503504
self.queues = []
504505
self.events = []
506+
507+
# Trigger callback
508+
self._server._callback_runner.trigger_event(
509+
EventTypes.ON_REQUEST.value,
510+
active_requests=self._server.active_requests,
511+
litserver=self._server,
512+
)
513+
505514
for uid in uids:
506515
request_el = request.model_copy()
507516
request_el.n = 1

src/litserve/specs/openai_embedding.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from fastapi import status as status_code
2525
from pydantic import BaseModel
2626

27+
from litserve.callbacks.base import EventTypes
2728
from litserve.constants import _DEFAULT_LIT_API_PATH
2829
from litserve.specs.base import LitSpec
2930
from litserve.utils import LitAPIStatus, ResponseBufferItem
@@ -261,6 +262,13 @@ async def embeddings_endpoint(self, request: EmbeddingRequest) -> EmbeddingRespo
261262
event = asyncio.Event()
262263
self.response_buffer[uid] = ResponseBufferItem(event=event)
263264

265+
# Trigger callback
266+
self._server._callback_runner.trigger_event(
267+
EventTypes.ON_REQUEST.value,
268+
active_requests=self._server.active_requests,
269+
litserver=self._server,
270+
)
271+
264272
self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), request.model_copy()))
265273
await event.wait()
266274

tests/unit/test_callbacks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,45 @@ async def test_request_tracker(capfd):
8080
await run_simple_request(server, 4)
8181
captured = capfd.readouterr()
8282
assert "Active requests: 4" in captured.out, f"Expected pattern not found in output: {captured.out}"
83+
84+
85+
@pytest.mark.asyncio
86+
async def test_request_tracker_with_spec(capfd):
87+
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec
88+
from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI
89+
90+
lit_api = TestEmbedAPI(spec=OpenAIEmbeddingSpec())
91+
server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()])
92+
93+
with wrap_litserve_start(server) as server:
94+
async with (
95+
LifespanManager(server.app) as manager,
96+
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://test") as ac,
97+
):
98+
resp = await ac.post("/v1/embeddings", json={"input": "test", "model": "test"})
99+
assert resp.status_code == 200
100+
101+
captured = capfd.readouterr()
102+
assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}"
103+
104+
105+
@pytest.mark.asyncio
106+
async def test_request_tracker_with_openai_spec(capfd):
107+
from litserve.specs.openai import OpenAISpec
108+
from litserve.test_examples.openai_spec_example import TestAPI
109+
110+
lit_api = TestAPI(spec=OpenAISpec())
111+
server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()])
112+
113+
with wrap_litserve_start(server) as server:
114+
async with (
115+
LifespanManager(server.app) as manager,
116+
AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://test") as ac,
117+
):
118+
resp = await ac.post(
119+
"/v1/chat/completions", json={"messages": [{"role": "user", "content": "test"}], "model": "test"}
120+
)
121+
assert resp.status_code == 200
122+
123+
captured = capfd.readouterr()
124+
assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}"

0 commit comments

Comments (0)