Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/litserve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from litserve.callbacks import Callback
from litserve.loggers import Logger
from litserve.server import LitServer, Request, Response
from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec
from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec, WebSocketSpec
from litserve.utils import configure_logging

configure_logging()
Expand All @@ -29,6 +29,7 @@
"Response",
"OpenAISpec",
"OpenAIEmbeddingSpec",
"WebSocketSpec",
"test_examples",
"Callback",
"Logger",
Expand Down
9 changes: 9 additions & 0 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,11 +515,20 @@ async def stream_predict(request: self.request_type) -> self.response_type:
for spec in self._specs:
spec: LitSpec
# TODO check that path is not clashing
# add http endpoints
for path, endpoint, methods in spec.endpoints:
self.app.add_api_route(
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
)

# add websocket endpoints
for path, endpoint in spec.ws_endpoints:
self.app.add_api_websocket_route(
path,
endpoint,
dependencies=[Depends(self.setup_auth())],
)

for middleware in self.middlewares:
if isinstance(middleware, tuple):
middleware, kwargs = middleware
Expand Down
3 changes: 2 additions & 1 deletion src/litserve/specs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from litserve.specs.openai import OpenAISpec
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec
from litserve.specs.websocket import WebSocketSpec

__all__ = ["OpenAISpec", "OpenAIEmbeddingSpec"]
__all__ = ["OpenAISpec", "OpenAIEmbeddingSpec", "WebSocketSpec"]
9 changes: 9 additions & 0 deletions src/litserve/specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LitSpec:

def __init__(self):
    # HTTP route registrations as (path, endpoint, methods) tuples,
    # consumed by the server when wiring FastAPI routes.
    self._endpoints = []
    # WebSocket route registrations as (path, endpoint) tuples.
    self._ws_endpoints = []

    # Back-reference to the owning LitServer; populated later via setup().
    self._server: LitServer = None

Expand All @@ -40,10 +41,18 @@ def add_endpoint(self, path: str, endpoint: Callable, methods: List[str]):
"""Register an endpoint in the spec."""
self._endpoints.append((path, endpoint, methods))

def add_ws_endpoint(self, path: str, endpoint: Callable):
    """Register a websocket route on this spec.

    Args:
        path: URL path the websocket route is mounted at.
        endpoint: Async callable invoked with the accepted ``WebSocket``.
    """
    entry = (path, endpoint)
    self._ws_endpoints.append(entry)

@property
def endpoints(self):
    """A shallow copy of the registered HTTP routes as (path, endpoint, methods) tuples."""
    return list(self._endpoints)

@property
def ws_endpoints(self):
    """A shallow copy of the registered websocket routes as (path, endpoint) tuples."""
    return list(self._ws_endpoints)

@abstractmethod
def decode_request(self, request, meta_kwargs):
"""Convert the request payload to your model input."""
Expand Down
94 changes: 94 additions & 0 deletions src/litserve/specs/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect, status

from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

if TYPE_CHECKING:
from litserve import LitServer

logger = logging.getLogger(__name__)


class WebSocketSpec(LitSpec):
    """LitSpec that serves predictions over a persistent WebSocket connection.

    Clients connect to ``api_path``, send JSON payloads, and receive JSON
    responses; the connection stays open across multiple request/response
    round-trips until the client disconnects or an error occurs.
    """

    def __init__(self, api_path: str = "/predict"):
        super().__init__()

        # register the websocket endpoint
        self.add_ws_endpoint(api_path, self.ws_predict)

    def setup(self, server: "LitServer"):
        super().setup(server)

        print("WebSocket Spec is ready.")

    def decode_request(self, request: Dict, context_kwargs: Optional[dict] = None) -> Any:
        # Pass the raw JSON payload through unchanged; the user's LitAPI
        # is responsible for extracting the fields it needs.
        return request

    def encode_response(self, output: Any, context_kwargs: Optional[dict] = None) -> Dict[str, Any]:
        # Pass through unchanged; ws_predict enforces the dict requirement.
        return output

    async def ws_predict(self, websocket: WebSocket):
        """Serve predictions over a single WebSocket connection.

        Each received JSON message is forwarded to an inference worker via
        the server's request queue; the worker's response is sent back on
        the same connection. Loops until the client disconnects or an
        unrecoverable error occurs.
        """
        # TODO: Determine if a dedicated connection manager is needed to effectively maintain active connections
        await websocket.accept()
        response_queue_id = self.response_queue_id
        logger.debug("Received WebSocket connection: %s", websocket.client)
        try:
            while True:
                # TODO: Discuss support for additional payload formats beyond JSON.
                payload = await websocket.receive_json()

                uid = uuid.uuid4()
                event = asyncio.Event()
                self._server.response_buffer[uid] = event
                # Send request to inference worker
                self._server.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload))
                # Wait for the response
                await event.wait()
                response, response_status = self._server.response_buffer.pop(uid)

                # Worker-side failures are routed through the generic error
                # path below so the client gets an error message before close.
                if response_status == LitAPIStatus.ERROR:
                    logger.error("Error in WebSocket communication: %s", response)
                    raise RuntimeError("Error in WebSocket communication")

                logger.debug(response)

                if not isinstance(response, dict):
                    # BUG FIX: the original passed three comma-separated
                    # arguments to ValueError, which renders as a tuple-shaped
                    # message; build a single string instead.
                    raise ValueError(
                        f"Expected response to be a dictionary, but got type {type(response)}. "
                        "The response should be a dictionary to ensure proper compatibility with the WebSocketSpec. "
                        "Please ensure that your response is a dictionary."
                    )

                # Send successful response back to client
                await websocket.send_json(response)

        except WebSocketDisconnect:
            logger.debug("WebSocket client disconnected")
        except Exception as e:
            logger.exception("Error in WebSocket communication", exc_info=e)
            # TODO: Catch unsupported payload formats and send error message
            await websocket.send_json({"error": "Internal server error", "details": str(e)})
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
        finally:
            # BUG FIX: the original unconditionally called close() here, which
            # raises RuntimeError when the socket is already closed (client
            # disconnect, or the error path above). Tolerate a double close.
            try:
                await websocket.close()
            except RuntimeError:
                # Close message already sent — nothing left to do.
                pass
            logger.debug("WebSocket connection closed")
15 changes: 15 additions & 0 deletions src/litserve/test_examples/websocket_spec_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from litserve import LitAPI


class WebSocketLitAPI(LitAPI):
    """Minimal LitAPI used to exercise the WebSocket spec in examples and tests."""

    def setup(self, device):
        # IDIOM FIX (E731): define a named function rather than assigning a
        # lambda to an attribute; behavior is unchanged.
        def _model(x):
            return f"Processed: {x}"

        self.model = _model

    def decode_request(self, request):
        # Pull the "input" field out of the JSON payload, falling back to a
        # fixed default when the client omits it.
        return request.get("input", "default_input")

    def predict(self, x):
        # Run the toy model on the decoded input.
        return self.model(x)

    def encode_response(self, output):
        # Wrap the prediction in the JSON shape WebSocket clients expect.
        return {"output": output}
7 changes: 7 additions & 0 deletions tests/e2e/default_websocket_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import litserve as ls
from litserve import WebSocketSpec
from litserve.test_examples.websocket_spec_example import WebSocketLitAPI

if __name__ == "__main__":
    # Launch a LitServer exposing WebSocketLitAPI through the WebSocket spec
    # (default route: ws://<host>:<port>/predict). Used by the e2e tests.
    server = ls.LitServer(WebSocketLitAPI(), spec=WebSocketSpec())
    server.run()
22 changes: 22 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import psutil
import requests
from openai import OpenAI
from websockets.sync.client import connect as websocket_connect


def e2e_from_file(filename):
Expand Down Expand Up @@ -379,3 +380,24 @@ def test_openai_embedding_parity():
assert len(response.data) == 2, f"Expected 2 embeddings but got {len(response.data)}"
for data in response.data:
assert len(data.embedding) == 768, f"Expected 768 dimensions but got {len(data.embedding)}"


@e2e_from_file("tests/e2e/default_websocket_spec.py")
def test_websocket_parity():
    """End-to-end check of the WebSocket spec: a JSON payload round-trips
    through the server, and a non-JSON payload yields an error response."""
    with websocket_connect("ws://127.0.0.1:8000/predict") as websocket:
        # Send a JSON payload
        websocket.send(json.dumps({"input": "test_input"}))
        response = websocket.recv()
        response = json.loads(response)
        assert response["output"] == "Processed: test_input", (
            f"Server didn't return expected output\nWebSocket client output: {response}"
        )

        # Send other types of payloads
        # TODO: make this test more thorough (cover more payload types)
        websocket.send("text_payload")
        response = websocket.recv()
        response = json.loads(response)
        assert "error" in response, (
            f"Server didn't return expected error for text payload\nWebSocket client output: {response}"
        )
21 changes: 21 additions & 0 deletions tests/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,24 @@ async def test_openai_embedding_spec_with_batching(openai_embedding_request_data
assert len(resp2.json()["data"]) == 4, "Length of data should be 4"
assert len(resp1.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"
assert len(resp2.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"


# TODO: Find a way to test websocket spec
# Currently, the following test seems to get stuck somewhere in the async code.

# @pytest.mark.asyncio
# async def test_websocket_litapi():
# server = ls.LitServer(WebSocketLitAPI(), spec=ls.WebSocketSpec())

# with wrap_litserve_start(server) as server:
# async with LifespanManager(server.app) as manager:
# client = TestClient(server.app)

# with client.websocket_connect("/predict") as websocket:
# # Send a JSON payload
# payload = {"input": "test_input"}
# websocket.send_json(payload)

# # Receive the response (should work now that lifespan is running)
# response = websocket.receive_json()
# assert response["output"] == "Processed: test_input"
Loading