|
14 | 14 | import inspect |
15 | 15 | import io |
16 | 16 | import json |
| 17 | +import re |
17 | 18 | import threading |
18 | 19 | import time |
19 | 20 | from queue import Queue |
|
28 | 29 | from litserve import LitAPI |
29 | 30 | from litserve.callbacks import CallbackRunner |
30 | 31 | from litserve.loops import ( |
| 32 | + ContinuousBatchingLoop, |
| 33 | + DefaultLoop, |
| 34 | + LitLoop, |
| 35 | + Output, |
31 | 36 | _BaseLoop, |
32 | 37 | inference_worker, |
| 38 | + notify_timed_out_requests, |
33 | 39 | run_batched_loop, |
34 | 40 | run_batched_streaming_loop, |
35 | 41 | run_single_loop, |
@@ -495,3 +501,153 @@ def test_get_default_loop(): |
495 | 501 | assert isinstance(loop, ls.loops.BatchedStreamingLoop), ( |
496 | 502 | "BatchedStreamingLoop must be returned when stream=True and max_batch_size>1" |
497 | 503 | ) |
| 504 | + |
| 505 | + |
@pytest.fixture
def lit_loop_setup():
    """Provide a fresh LitLoop, a mocked LitAPI with a 0.1s request timeout, and an empty queue."""
    request_queue = Queue()
    mocked_api = MagicMock(request_timeout=0.1)
    return LitLoop(), mocked_api, request_queue
| 512 | + |
| 513 | + |
def test_lit_loop_get_batch_requests(lit_loop_setup):
    """get_batch_requests drains up to max_batch_size fresh requests and reports no timeouts."""
    loop, api, queue = lit_loop_setup
    for uid, value in (("UUID-001", 4.0), ("UUID-002", 5.0)):
        queue.put((0, uid, time.monotonic(), {"input": value}))

    batch, timed_out = loop.get_batch_requests(api, queue, 2, 0.001)

    expected = [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})]
    assert len(batch) == 2
    assert batch == expected
    assert timed_out == []
| 522 | + |
| 523 | + |
def test_lit_loop_get_request(lit_loop_setup):
    """get_request returns the queued 4-tuple intact, and None once the queue is drained."""
    loop, _, queue = lit_loop_setup
    enqueued_at = time.monotonic()
    queue.put((0, "UUID-001", enqueued_at, {"input": 4.0}))

    queue_id, uid, timestamp, payload = loop.get_request(queue, timeout=1)

    assert (queue_id, uid, timestamp, payload) == (0, "UUID-001", enqueued_at, {"input": 4.0})
    # An empty queue must yield None rather than raising.
    assert loop.get_request(queue, timeout=0.001) is None
| 534 | + |
| 535 | + |
def test_lit_loop_put_response(lit_loop_setup):
    """put_response must enqueue (uid, (payload, status)) on the addressed response queue."""
    # Only the loop itself is needed here; the fixture's mocked API and request
    # queue are irrelevant to put_response (the original bound an unused local).
    lit_loop, _, _ = lit_loop_setup
    response_queues = [Queue()]
    lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK)
    response = response_queues[0].get()
    assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))
| 542 | + |
| 543 | + |
def test_notify_timed_out_requests():
    """Every timed-out uid must receive an HTTPException payload tagged with ERROR status."""
    response_queues = [Queue()]
    timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")]

    notify_timed_out_requests(response_queues, timed_out_uids)

    # Responses are delivered in the same order the uids were supplied.
    for _, expected_uid in timed_out_uids:
        uid, (payload, status) = response_queues[0].get()
        assert uid == expected_uid
        assert isinstance(payload, HTTPException)
        assert status == LitAPIStatus.ERROR
| 563 | + |
| 564 | + |
class ContinuousBatchingAPI(ls.LitAPI):
    """Minimal LitAPI for exercising ContinuousBatchingLoop: each request streams 0..4."""

    def setup(self, spec: Optional[LitSpec]):
        # Maps request uid -> {"outputs": [...]} of values still to be streamed.
        self.model = {}

    def add_request(self, uid: str, request):
        self.model[uid] = {"outputs": list(range(5))}

    def decode_request(self, input: str):
        return input

    def encode_response(self, output: str):
        return {"output": output}

    def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
        """Emit one pending value per active request; finish and drop drained requests."""
        outputs = []
        for uid, state in self.model.items():
            if state["outputs"]:
                outputs.append(Output(uid, state["outputs"].pop(0), LitAPIStatus.OK))
        # Build the emitted-uid set once instead of rescanning `outputs` for
        # every key (the original list comprehension per key was O(n^2)).
        emitted = {o.uid for o in outputs}
        for uid in list(self.model):
            if uid not in emitted:
                outputs.append(Output(uid, "", LitAPIStatus.FINISH_STREAMING))
                del self.model[uid]
        return outputs
| 591 | + |
| 592 | + |
@pytest.mark.parametrize(
    ("stream", "max_batch_size", "error_msg"),
    [
        (True, 4, "`lit_api.unbatch` must generate values using `yield`."),
        (True, 1, "`lit_api.encode_response` must generate values using `yield`."),
    ],
)
def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg):
    """pre_setup must reject streaming configs whose hooks are not generators."""
    api = ls.test_examples.SimpleLitAPI()
    api.stream = stream
    api.max_batch_size = max_batch_size
    with pytest.raises(ValueError, match=error_msg):
        DefaultLoop().pre_setup(api, None)
| 607 | + |
| 608 | + |
@pytest.fixture
def continuous_batching_setup():
    """Build a streaming ContinuousBatchingAPI (batch size 2) plus a fresh loop and queues."""
    api = ContinuousBatchingAPI()
    api.stream = True
    api.request_timeout = 0.1
    api.pre_setup(2, None)
    api.setup(None)
    return api, ContinuousBatchingLoop(), Queue(), [Queue()]
| 620 | + |
| 621 | + |
def test_continuous_batching_pre_setup(continuous_batching_setup):
    """pre_setup must reject a non-streaming API with an actionable error message."""
    api, loop, _, _ = continuous_batching_setup
    api.stream = False
    expected = re.escape(
        "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
    )
    with pytest.raises(ValueError, match=expected):
        loop.pre_setup(api, None)
| 632 | + |
| 633 | + |
def test_continuous_batching_run(continuous_batching_setup):
    """A single request must stream outputs 0..4 followed by a FINISH_STREAMING marker."""
    api, loop, request_queue, response_queues = continuous_batching_setup
    request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
    loop.run(api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)

    collected = []
    for expected in range(5):
        uid, (payload, status) = response_queues[0].get()
        value = json.loads(payload)["output"]
        assert uid == "UUID-001"
        assert status == LitAPIStatus.OK
        assert value == expected
        collected.append(value)
    assert collected == list(range(5)), "API must return a sequence of numbers from 0 to 4"

    # The stream must terminate with an empty FINISH_STREAMING response.
    _, (payload, status) = response_queues[0].get()
    assert json.loads(payload)["output"] == ""
    assert status == LitAPIStatus.FINISH_STREAMING
0 commit comments