Skip to content

Commit ccc322a

Browse files
authored
fix custom exceptions (#425)
* always pickle exceptions * fix * fix * fix
1 parent 38d7169 commit ccc322a

File tree

5 files changed

+26
-7
lines changed

5 files changed

+26
-7
lines changed

src/litserve/loops/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import asyncio
1515
import inspect
1616
import logging
17+
import pickle
1718
import signal
1819
import sys
1920
import time
@@ -256,6 +257,7 @@ def put_response(
256257
def put_error_response(
257258
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
258259
) -> None:
260+
error = pickle.dumps(error)
259261
self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR)
260262

261263
def __del__(self):

src/litserve/loops/simple_loops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,11 @@ def run_single_loop(
109109
"Please check the error trace for more details.",
110110
uid,
111111
)
112-
self.put_response(
112+
self.put_error_response(
113113
response_queues=response_queues,
114114
response_queue_id=response_queue_id,
115115
uid=uid,
116-
response_data=e,
117-
status=LitAPIStatus.ERROR,
116+
error=e,
118117
)
119118

120119
def __call__(
@@ -226,7 +225,7 @@ def run_batched_loop(
226225
"Please check the error trace for more details."
227226
)
228227
for response_queue_id, uid in zip(response_queue_ids, uids):
229-
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
228+
self.put_error_response(response_queues, response_queue_id, uid, e)
230229

231230
def __call__(
232231
self,

src/litserve/loops/streaming_loops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def run_streaming_loop(
102102
"Please check the error trace for more details.",
103103
uid,
104104
)
105-
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
105+
self.put_error_response(response_queues, response_queue_id, uid, e)
106106

107107
def __call__(
108108
self,
@@ -207,7 +207,7 @@ def run_batched_streaming_loop(
207207
"Please check the error trace for more details."
208208
)
209209
for response_queue_id, uid in zip(response_queue_ids, uids):
210-
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
210+
self.put_error_response(response_queues, response_queue_id, uid, e)
211211

212212
def __call__(
213213
self,

tests/test_loops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def put(self, item, block=False, timeout=None):
132132
response, status = args
133133
if status == LitAPIStatus.FINISH_STREAMING:
134134
raise StopIteration("interrupt iteration")
135-
if status == LitAPIStatus.ERROR and isinstance(response, StopIteration):
135+
if status == LitAPIStatus.ERROR:
136136
assert self.count // 2 == self.num_streamed_outputs, (
137137
f"Loop count must have incremented for {self.num_streamed_outputs} times."
138138
)

tests/test_simple.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,21 @@ def test_concurrent_requests(lit_server):
244244
assert response.json() == {"output": i**2}, "Server returns square of the input number"
245245
count += 1
246246
assert count == n_requests
247+
248+
249+
class CustomError(Exception):
250+
def __init__(self, arg1, arg2, arg3):
251+
super().__init__("Test exception")
252+
253+
254+
class ExceptionAPI(SimpleLitAPI):
255+
def predict(self, x):
256+
raise CustomError("This", "is", "a test")
257+
258+
259+
def test_exception():
260+
server = LitServer(ExceptionAPI(), accelerator="cpu", devices=1)
261+
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
262+
response = client.post("/predict", json={"input": 4.0})
263+
assert response.status_code == 500
264+
assert response.json() == {"detail": "Internal Server Error"}

0 commit comments

Comments
 (0)