Skip to content

Commit b3d4a97

Browse files
authored
fix: update type hints for _default_unbatch and _spec attributes in LitAPI class (#435)
* fix: update type hints for _default_unbatch and _spec attributes in LitAPI class * fix: improve error message for unimplemented unbatch method in LitAPI class and add corresponding test
1 parent 79e6dd1 commit b3d4a97

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

src/litserve/api.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import warnings
1616
from abc import ABC, abstractmethod
1717
from queue import Queue
18-
from typing import Optional
18+
from typing import Callable, Optional
1919

2020
from pydantic import BaseModel
2121

@@ -24,8 +24,8 @@
2424

2525
class LitAPI(ABC):
2626
_stream: bool = False
27-
_default_unbatch: callable = None
28-
_spec: LitSpec = None
27+
_default_unbatch: Optional[Callable] = None
28+
_spec: Optional[LitSpec] = None
2929
_device: Optional[str] = None
3030
_logger_queue: Optional[Queue] = None
3131
request_timeout: Optional[float] = None
@@ -76,6 +76,11 @@ def _unbatch_stream(self, output_stream):
7676

7777
def unbatch(self, output):
7878
"""Convert a batched output to a list of outputs."""
79+
if self._default_unbatch is None:
80+
raise ValueError(
81+
"Default implementation for `LitAPI.unbatch` method was not found. "
82+
"Please implement the `LitAPI.unbatch` method."
83+
)
7984
return self._default_unbatch(output)
8085

8186
def encode_response(self, output, **kwargs):

tests/test_litapi.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ def test_default_batch_unbatch():
7272
assert api.unbatch(output) == inputs, "Default unbatch should not change input"
7373

7474

75+
def test_default_unbatch_not_implemented():
76+
api = TestDefaultBatchedAPI()
77+
with pytest.raises(ValueError, match="Default implementation for `LitAPI.unbatch` method was not found."):
78+
api.unbatch(None)
79+
80+
7581
class TestStreamAPIBatched(TestStreamAPI):
7682
def predict(self, x):
7783
for i in range(4):

0 commit comments

Comments
 (0)