Skip to content

Commit fcce973

Browse files
aniketmaurya and Borda authored
add tests for continuous batching and Default loops (#396)
* add test * update * fix * addt test * update * bump version * Update src/litserve/loops.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent e71f38b commit fcce973

File tree

3 files changed

+163
-6
lines changed

3 files changed

+163
-6
lines changed

src/litserve/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
__version__ = "0.2.6.dev1"
14+
__version__ = "0.2.6.dev2"
1515
__author__ = "Lightning-AI et al."
1616
__author_email__ = "community@lightning.ai"
1717
__license__ = "Apache-2.0"

src/litserve/loops.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -495,9 +495,6 @@ def __init__(self):
495495
self._context = {}
496496

497497
def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
498-
if max_batch_size <= 1:
499-
raise ValueError("max_batch_size must be greater than 1")
500-
501498
batches, timed_out_uids = collate_requests(
502499
lit_api,
503500
request_queue,
@@ -507,8 +504,10 @@ def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_si
507504
return batches, timed_out_uids
508505

509506
def get_request(self, request_queue: Queue, timeout: float = 1.0):
510-
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
511-
return response_queue_id, uid, timestamp, x_enc
507+
try:
508+
return request_queue.get(timeout=timeout)
509+
except Empty:
510+
return None
512511

513512
def populate_context(self, lit_spec: LitSpec, request: Any):
514513
if lit_spec and hasattr(lit_spec, "populate_context"):
@@ -751,6 +750,8 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
751750

752751
def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
753752
"""Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
753+
if hasattr(lit_api, "add_request"):
754+
lit_api.add_request(uid, request)
754755
decoded_request = lit_api.decode_request(request)
755756
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}
756757

tests/test_loops.py

Lines changed: 156 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
import inspect
1515
import io
1616
import json
17+
import re
1718
import threading
1819
import time
1920
from queue import Queue
@@ -28,8 +29,13 @@
2829
from litserve import LitAPI
2930
from litserve.callbacks import CallbackRunner
3031
from litserve.loops import (
32+
ContinuousBatchingLoop,
33+
DefaultLoop,
34+
LitLoop,
35+
Output,
3136
_BaseLoop,
3237
inference_worker,
38+
notify_timed_out_requests,
3339
run_batched_loop,
3440
run_batched_streaming_loop,
3541
run_single_loop,
@@ -495,3 +501,153 @@ def test_get_default_loop():
495501
assert isinstance(loop, ls.loops.BatchedStreamingLoop), (
496502
"BatchedStreamingLoop must be returned when stream=True and max_batch_size>1"
497503
)
504+
505+
506+
@pytest.fixture
507+
def lit_loop_setup():
508+
lit_loop = LitLoop()
509+
lit_api = MagicMock(request_timeout=0.1)
510+
request_queue = Queue()
511+
return lit_loop, lit_api, request_queue
512+
513+
514+
def test_lit_loop_get_batch_requests(lit_loop_setup):
515+
lit_loop, lit_api, request_queue = lit_loop_setup
516+
request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
517+
request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0}))
518+
batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001)
519+
assert len(batches) == 2
520+
assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})]
521+
assert timed_out_uids == []
522+
523+
524+
def test_lit_loop_get_request(lit_loop_setup):
525+
lit_loop, _, request_queue = lit_loop_setup
526+
t = time.monotonic()
527+
request_queue.put((0, "UUID-001", t, {"input": 4.0}))
528+
response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1)
529+
assert uid == "UUID-001"
530+
assert response_queue_id == 0
531+
assert timestamp == t
532+
assert x_enc == {"input": 4.0}
533+
assert lit_loop.get_request(request_queue, timeout=0.001) is None
534+
535+
536+
def test_lit_loop_put_response(lit_loop_setup):
537+
lit_loop, _, request_queue = lit_loop_setup
538+
response_queues = [Queue()]
539+
lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK)
540+
response = response_queues[0].get()
541+
assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
542+
543+
544+
def test_notify_timed_out_requests():
545+
response_queues = [Queue()]
546+
547+
# Simulate timed out requests
548+
timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")]
549+
550+
# Call the function to notify timed out requests
551+
notify_timed_out_requests(response_queues, timed_out_uids)
552+
553+
# Check the responses in the response queue
554+
response_1 = response_queues[0].get()
555+
response_2 = response_queues[0].get()
556+
557+
assert response_1[0] == "UUID-001"
558+
assert response_1[1][1] == LitAPIStatus.ERROR
559+
assert isinstance(response_1[1][0], HTTPException)
560+
assert response_2[0] == "UUID-002"
561+
assert isinstance(response_2[1][0], HTTPException)
562+
assert response_2[1][1] == LitAPIStatus.ERROR
563+
564+
565+
class ContinuousBatchingAPI(ls.LitAPI):
566+
def setup(self, spec: Optional[LitSpec]):
567+
self.model = {}
568+
569+
def add_request(self, uid: str, request):
570+
self.model[uid] = {"outputs": list(range(5))}
571+
572+
def decode_request(self, input: str):
573+
return input
574+
575+
def encode_response(self, output: str):
576+
return {"output": output}
577+
578+
def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
579+
outputs = []
580+
for k in self.model:
581+
v = self.model[k]
582+
if v["outputs"]:
583+
o = v["outputs"].pop(0)
584+
outputs.append(Output(k, o, LitAPIStatus.OK))
585+
keys = list(self.model.keys())
586+
for k in keys:
587+
if k not in [o.uid for o in outputs]:
588+
outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING))
589+
del self.model[k]
590+
return outputs
591+
592+
593+
@pytest.mark.parametrize(
594+
("stream", "max_batch_size", "error_msg"),
595+
[
596+
(True, 4, "`lit_api.unbatch` must generate values using `yield`."),
597+
(True, 1, "`lit_api.encode_response` must generate values using `yield`."),
598+
],
599+
)
600+
def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg):
601+
lit_api = ls.test_examples.SimpleLitAPI()
602+
lit_api.stream = stream
603+
lit_api.max_batch_size = max_batch_size
604+
loop = DefaultLoop()
605+
with pytest.raises(ValueError, match=error_msg):
606+
loop.pre_setup(lit_api, None)
607+
608+
609+
@pytest.fixture
610+
def continuous_batching_setup():
611+
lit_api = ContinuousBatchingAPI()
612+
lit_api.stream = True
613+
lit_api.request_timeout = 0.1
614+
lit_api.pre_setup(2, None)
615+
lit_api.setup(None)
616+
request_queue = Queue()
617+
response_queues = [Queue()]
618+
loop = ContinuousBatchingLoop()
619+
return lit_api, loop, request_queue, response_queues
620+
621+
622+
def test_continuous_batching_pre_setup(continuous_batching_setup):
623+
lit_api, loop, request_queue, response_queues = continuous_batching_setup
624+
lit_api.stream = False
625+
with pytest.raises(
626+
ValueError,
627+
match=re.escape(
628+
"Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
629+
),
630+
):
631+
loop.pre_setup(lit_api, None)
632+
633+
634+
def test_continuous_batching_run(continuous_batching_setup):
635+
lit_api, loop, request_queue, response_queues = continuous_batching_setup
636+
request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
637+
loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)
638+
639+
results = []
640+
for i in range(5):
641+
response = response_queues[0].get()
642+
uid, (response_data, status) = response
643+
o = json.loads(response_data)["output"]
644+
assert o == i
645+
assert status == LitAPIStatus.OK
646+
assert uid == "UUID-001"
647+
results.append(o)
648+
assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4"
649+
response = response_queues[0].get()
650+
uid, (response_data, status) = response
651+
o = json.loads(response_data)["output"]
652+
assert o == ""
653+
assert status == LitAPIStatus.FINISH_STREAMING

0 commit comments

Comments (0)