Skip to content

Commit 6083b70

Browse files
authored
fix OpenAI embedding spec for batching (#500)
* refactor: enhance OpenAIEmbeddingSpec with improved response handling and new utility methods - Introduced `get_num_items` method in `EmbeddingRequest` for better input management. - Refactored embedding response handling into `_handle_embedding_response` for clarity and error checking. - Updated endpoint registration to use `embeddings_endpoint` for consistency. - Improved logging and response validation in the embeddings endpoint. * update * update * refactor: enhance OpenAIEmbeddingSpec and TestAPI for improved input handling and response validation - Updated `TestAPI` to support batching in the `predict` method, allowing for multiple inputs. - Added `EMBEDDING_API_EXAMPLE_BATCHING` to provide guidance on using the OpenAI Embedding spec with batching. - Improved response validation in `OpenAIEmbeddingSpec` to handle mismatches between requested and returned embeddings. - Removed unused methods and streamlined the code for better clarity and maintainability. * fix: improve error handling in BatchedLoop and add OpenAI embedding test - Enhanced error logging in `BatchedLoop` to provide clearer messages when output length mismatches expected input count. - Updated HTTPException to include a detailed message for better debugging. - Introduced `openai_embedding_with_batching.py` for end-to-end testing of the OpenAI embedding API with batching support. - Added assertions in `test_e2e_openai_embedding_with_batching` to validate model and embedding dimensions. * fix * fix * fix: update OpenAI embedding prediction to handle single and multiple inputs - Modified `predict` method in `TestEmbedAPI` to support both single and batch inputs by adjusting the random embedding generation. - Added comprehensive tests for OpenAI embedding spec, covering single input, multiple inputs, usage validation, and error handling for missing or incorrect responses. - Ensured that the tests validate the expected structure and content of the API responses. 
* test: enhance OpenAI embedding tests for batching and error handling - Added tests for client-side batching to validate error responses when dynamic batching is used. - Improved assertions for status codes and response content in existing tests for better clarity. - Utilized `copy.deepcopy` to ensure request data integrity during concurrent tests. * update * fix ci * fix test
1 parent 1afd62c commit 6083b70

File tree

10 files changed

+354
-163
lines changed

10 files changed

+354
-163
lines changed

src/litserve/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def pre_setup(self, spec: Optional[LitSpec]):
176176

177177
if spec:
178178
self._spec = spec
179+
spec._max_batch_size = self.max_batch_size
179180
spec.pre_setup(self)
180181

181182
def set_logger_queue(self, queue: Queue):

src/litserve/loops/simple_loops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,14 @@ def run_batched_loop(
326326
outputs = lit_api.unbatch(y)
327327

328328
if len(outputs) != num_inputs:
329+
actual = len(outputs)
329330
logger.error(
330-
f"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
331-
"Please check the predict/unbatch method of the LitAPI implementation."
331+
f"LitAPI.predict/unbatch returned {actual} outputs, but expected {num_inputs}. "
332+
"This suggests a possible issue in the predict or unbatch implementation.\n"
333+
"Hint: Ensure that LitAPI.predict returns a list with one prediction per input — "
334+
"the length of the returned list should match the number of inputs."
332335
)
333-
raise HTTPException(500, "Batch size mismatch")
336+
raise HTTPException(500, detail="Batch size mismatch")
334337

335338
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
336339
y_enc_list = []

src/litserve/specs/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from litserve.specs.openai import OpenAISpec
2-
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec
1+
from litserve.specs.openai import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, OpenAISpec
2+
from litserve.specs.openai_embedding import EmbeddingRequest, EmbeddingResponse, OpenAIEmbeddingSpec
33

4-
__all__ = ["OpenAISpec", "OpenAIEmbeddingSpec"]
4+
__all__ = [
5+
"OpenAISpec",
6+
"OpenAIEmbeddingSpec",
7+
"EmbeddingRequest",
8+
"EmbeddingResponse",
9+
"ChatCompletionRequest",
10+
"ChatCompletionResponse",
11+
"ChatCompletionChunk",
12+
]

src/litserve/specs/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self):
2525
self._endpoints = []
2626

2727
self._server: LitServer = None
28+
self._max_batch_size = 1
2829

2930
@property
3031
def stream(self):

src/litserve/specs/openai_embedding.py

Lines changed: 106 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
import asyncio
1515
import inspect
1616
import logging
17+
import sys
1718
import time
1819
import uuid
19-
from typing import TYPE_CHECKING, List, Literal, Optional, Union
20+
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
2021

2122
from fastapi import HTTPException, Request, Response, status
2223
from fastapi import status as status_code
@@ -28,6 +29,9 @@
2829
logger = logging.getLogger(__name__)
2930

3031
if TYPE_CHECKING:
32+
import numpy as np
33+
import torch
34+
3135
from litserve import LitServer
3236

3337

@@ -38,6 +42,14 @@ class EmbeddingRequest(BaseModel):
3842
encoding_format: Literal["float", "base64"] = "float"
3943
user: Optional[str] = None
4044

45+
def get_num_items(self) -> int:
46+
"""Return the number of sentences or tokens in the input."""
47+
if isinstance(self.input, list):
48+
if isinstance(self.input[0], list):
49+
return len(self.input[0])
50+
return len(self.input)
51+
return 1
52+
4153
def ensure_list(self):
4254
return self.input if isinstance(self.input, list) else [self.input]
4355

@@ -66,34 +78,54 @@ class EmbeddingResponse(BaseModel):
6678
```python
6779
import numpy as np
6880
from typing import List
69-
from litserve import LitAPI, OpenAIEmbeddingSpec
81+
from litserve.specs import OpenAIEmbeddingSpec, EmbeddingRequest
82+
import litserve as ls
7083
71-
class TestAPI(LitAPI):
84+
class TestAPI(ls.LitAPI):
7285
def setup(self, device):
7386
self.model = None
7487
75-
def decode_request(self, request) -> List[str]:
76-
return request.ensure_list()
88+
def predict(self, inputs) -> List[List[float]]:
89+
# inputs is a string
90+
return np.random.rand(1, 768).tolist()
7791
78-
def predict(self, x) -> List[List[float]]:
79-
return np.random.rand(len(x), 768).tolist()
80-
81-
def encode_response(self, output) -> dict:
82-
return {"embeddings": output}
8392
8493
if __name__ == "__main__":
85-
import litserve as ls
8694
server = ls.LitServer(TestAPI(), spec=OpenAIEmbeddingSpec())
8795
server.run()
8896
```
8997
"""
9098

99+
EMBEDDING_API_EXAMPLE_BATCHING = """
100+
Please follow the example below for guidance on how to use the OpenAI Embedding spec with batching:
101+
102+
```python
103+
import numpy as np
104+
from typing import List
105+
from litserve.specs import OpenAIEmbeddingSpec, EmbeddingRequest
106+
import litserve as ls
107+
108+
class TestAPI(ls.LitAPI):
109+
def setup(self, device):
110+
self.model = None
111+
112+
def predict(self, inputs) -> List[List[float]]:
113+
# inputs is a list of texts (List[str])
114+
return np.random.rand(len(inputs), 768)
115+
116+
if __name__ == "__main__":
117+
api = TestAPI(max_batch_size=2, batch_timeout=0.4)
118+
server = ls.LitServer(api, spec=OpenAIEmbeddingSpec())
119+
server.run()
120+
```
121+
"""
122+
91123

92124
class OpenAIEmbeddingSpec(LitSpec):
93125
def __init__(self):
94126
super().__init__()
95127
# register the endpoint
96-
self.add_endpoint("/v1/embeddings", self.embeddings, ["POST"])
128+
self.add_endpoint("/v1/embeddings", self.embeddings_endpoint, ["POST"])
97129
self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"])
98130

99131
def setup(self, server: "LitServer"):
@@ -124,17 +156,21 @@ def setup(self, server: "LitServer"):
124156
print("OpenAI Embedding Spec is ready.")
125157

126158
def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]:
127-
return request.ensure_list()
159+
return request.input
128160

129-
def encode_response(self, output: List[List[float]], context_kwargs: Optional[dict] = None) -> dict:
161+
def encode_response(
162+
self, output: List[List[float]], context_kwargs: Optional[dict] = None
163+
) -> Union[dict, EmbeddingResponse]:
130164
usage = {
131165
"prompt_tokens": context_kwargs.get("prompt_tokens", 0) if context_kwargs else 0,
132166
"total_tokens": context_kwargs.get("total_tokens", 0) if context_kwargs else 0,
133167
}
134168
return {"embeddings": output} | usage
135169

136-
def _validate_response(self, response: dict) -> None:
137-
if not isinstance(response, dict):
170+
def _validate_response(self, response: Union[dict, List[Embedding], Any]) -> None:
171+
if isinstance(response, list) and all(isinstance(item, Embedding) for item in response):
172+
return
173+
if not isinstance(response, (dict, EmbeddingResponse)):
138174
raise ValueError(
139175
f"Expected response to be a dictionary, but got type {type(response)}.",
140176
"The response should be a dictionary to ensure proper compatibility with the OpenAIEmbeddingSpec.\n\n"
@@ -152,8 +188,60 @@ def _validate_response(self, response: dict) -> None:
152188
f"{EMBEDDING_API_EXAMPLE}"
153189
)
154190

155-
async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
191+
def _handle_embedding_response(
192+
self, embeddings: Union[List, "np.ndarray", "torch.Tensor", "List[List[float]]"], num_items: int = 1
193+
) -> List[Embedding]:
194+
ndim = None
195+
if "torch" in sys.modules:
196+
import torch
197+
198+
if isinstance(embeddings, torch.Tensor):
199+
ndim = embeddings.ndim
200+
if "numpy" in sys.modules:
201+
import numpy as np
202+
203+
if isinstance(embeddings, np.ndarray):
204+
ndim = embeddings.ndim
205+
206+
# expand_dims for torch.Tensor or np.ndarray
207+
if ndim == 1:
208+
embeddings = embeddings[None, :]
209+
210+
if ndim is not None:
211+
embeddings = embeddings.tolist()
212+
213+
# expand dims for list of floats
214+
if isinstance(embeddings, (list, tuple)) and isinstance(embeddings[0], (int, float)):
215+
embeddings = [embeddings]
216+
217+
# check if we have total num_items number of embeddings vectors
218+
num_response_items = len(embeddings)
219+
if num_response_items != num_items:
220+
logger.debug("mismatch between number of requested and returned embeddings: %s", embeddings)
221+
raise ValueError(
222+
f"Mismatch between requested and returned embeddings: "
223+
f"expected {num_items}, but got {num_response_items}. "
224+
f"This may indicate a bug in the LitAPI embedding implementation."
225+
)
226+
227+
result = []
228+
for i, embedding in enumerate(embeddings):
229+
result.append(Embedding(index=i, embedding=embedding))
230+
231+
return result
232+
233+
async def embeddings_endpoint(self, request: EmbeddingRequest) -> EmbeddingResponse:
156234
response_queue_id = self.response_queue_id
235+
num_items = request.get_num_items()
236+
if num_items > 1 and self._max_batch_size > 1:
237+
raise HTTPException(
238+
status_code=400,
239+
detail=(
240+
"The OpenAIEmbedding spec does not support dynamic batching when client-side batching is used. "
241+
"To resolve this, either set `max_batch_size=1` or send a single input from the client."
242+
),
243+
)
244+
157245
logger.debug("Received embedding request: %s", request)
158246
uid = uuid.uuid4()
159247
event = asyncio.Event()
@@ -174,9 +262,9 @@ async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
174262
logger.debug(response)
175263

176264
self._validate_response(response)
265+
data: List[Embedding] = self._handle_embedding_response(response["embeddings"], num_items)
177266

178267
usage = UsageInfo(**response)
179-
data = [Embedding(index=i, embedding=embedding) for i, embedding in enumerate(response["embeddings"])]
180268

181269
return EmbeddingResponse(data=data, model=request.model, usage=usage)
182270

src/litserve/test_examples/openai_embedding_spec_example.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@ class TestEmbedAPI(LitAPI):
99
def setup(self, device):
1010
self.model = None
1111

12-
def decode_request(self, request) -> List[str]:
13-
return request.ensure_list()
14-
1512
def predict(self, x) -> List[List[float]]:
16-
return np.random.rand(len(x), 768).tolist()
13+
n = len(x) if isinstance(x, list) else 1
14+
return np.random.rand(n, 768).tolist()
1715

1816
def encode_response(self, output) -> dict:
1917
return {"embeddings": output}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
3+
import litserve as ls
4+
5+
6+
class EmbeddingsAPI(ls.LitAPI):
7+
def setup(self, device):
8+
def model(x):
9+
return np.random.rand(len(x), 768)
10+
11+
self.model = model
12+
13+
def predict(self, inputs):
14+
return self.model(inputs)
15+
16+
17+
if __name__ == "__main__":
18+
api = EmbeddingsAPI(max_batch_size=10, batch_timeout=2)
19+
server = ls.LitServer(api, spec=ls.OpenAIEmbeddingSpec())
20+
server.run(port=8000)

tests/e2e/test_e2e.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import subprocess
1717
import time
18+
from concurrent.futures import ThreadPoolExecutor
1819
from functools import wraps
1920

2021
import psutil
@@ -390,3 +391,26 @@ def test_e2e_default_async_streaming():
390391
outputs.append(json.loads(line.decode("utf-8"))["output"])
391392

392393
assert outputs == list(range(10)), "server didn't return expected output"
394+
395+
396+
@e2e_from_file("tests/e2e/openai_embedding_with_batching.py")
397+
def test_e2e_openai_embedding_with_batching():
398+
model = "text-embedding-3-large"
399+
client = OpenAI(
400+
base_url="http://127.0.0.1:8000/v1",
401+
api_key="lit", # required, but unused
402+
)
403+
futures = []
404+
with ThreadPoolExecutor(max_workers=2) as executor:
405+
futures.append(executor.submit(client.embeddings.create, model=model, input=["This is the first request"]))
406+
futures.append(executor.submit(client.embeddings.create, model=model, input=["This is the second request"]))
407+
futures.append(executor.submit(client.embeddings.create, model=model, input=["This is the first request"]))
408+
futures.append(executor.submit(client.embeddings.create, model=model, input=["This is the second request"]))
409+
410+
responses = [future.result() for future in futures]
411+
for response in responses:
412+
assert response.model == model, f"Expected model to be {model} but got {response.model}"
413+
assert len(response.data[0].embedding) == 768, (
414+
f"Expected 768 dimensions but got {len(response.data[0].embedding)}"
415+
)
416+
assert len(responses) == 4, f"Expected 4 responses but got {len(responses)}"

0 commit comments

Comments (0)