Skip to content
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.7"
__version__ = "0.2.8.dev0"
__author__ = "Lightning-AI et al."
__author_email__ = "community@lightning.ai"
__license__ = "Apache-2.0"
Expand Down
42 changes: 10 additions & 32 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import inspect
import logging
import pickle
import signal
import sys
import time
from abc import ABC
Expand All @@ -27,8 +26,8 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import LitAPIStatus
from litserve.zmq_queue import Producer

logger = logging.getLogger(__name__)
# By default, FastAPI writes form files larger than 1MB to disk, which prevents them from being serialized by multiprocessing
Expand Down Expand Up @@ -152,7 +151,7 @@ def __call__(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
Expand All @@ -164,9 +163,7 @@ def __call__(

async def _wrapper():
logger.info("Running LitLoop in a asyncio event loop")
future = self.schedule_task(
lit_api, lit_spec, request_queue, max_batch_size, batch_timeout, response_queues
)
future = self.schedule_task(lit_api, lit_spec, request_queue, max_batch_size, batch_timeout, transport)
_ = event_loop.create_task(future)
while True:
try:
Expand All @@ -176,7 +173,7 @@ async def _wrapper():
device,
worker_id,
request_queue,
response_queues,
transport,
max_batch_size,
batch_timeout,
stream,
Expand All @@ -196,7 +193,7 @@ async def _wrapper():
device,
worker_id,
request_queue,
response_queues,
transport,
max_batch_size,
batch_timeout,
stream,
Expand All @@ -211,7 +208,7 @@ def run(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
Expand All @@ -223,9 +220,7 @@ def run(

class LitLoop(_BaseLoop):
def __init__(self):
self.producer: Optional[Producer] = None
self._context = {}
self._setup_signal_handlers()

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
batches, timed_out_uids = collate_requests(
Expand All @@ -247,32 +242,15 @@ def populate_context(self, lit_spec: LitSpec, request: Any):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
self, transport: MessageTransport, response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
if self.producer:
self.producer.put((uid, (response_data, status)), consumer_id=response_queue_id)
else:
response_queues[response_queue_id].put((uid, (response_data, status)), block=False)
transport.send((uid, (response_data, status)), consumer_id=response_queue_id)

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
self, transport: MessageTransport, response_queue_id: int, uid: str, error: Exception
) -> None:
error = pickle.dumps(error)
self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR)

def __del__(self):
if self.producer:
self.producer.close()

def _setup_signal_handlers(self):
def cleanup_handler(signum=None, frame=None):
logging.debug("Worker process received shutdown signal")
if self.producer:
self.producer.close()
sys.exit(0)

signal.signal(signal.SIGINT, cleanup_handler)
signal.signal(signal.SIGTERM, cleanup_handler)
self.put_response(transport, response_queue_id, uid, error, LitAPIStatus.ERROR)


class DefaultLoop(LitLoop):
Expand Down
11 changes: 6 additions & 5 deletions src/litserve/loops/continuous_batching_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from litserve.callbacks import CallbackRunner
from litserve.loops.base import LitLoop
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import LitAPIStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -185,7 +186,7 @@ async def run(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
Expand Down Expand Up @@ -214,19 +215,19 @@ async def run(

response_data = lit_api.format_encoded_response(response_data)
if status == LitAPIStatus.ERROR:
self.put_error_response(response_queues, response_queue_id, uid, response_data)
self.put_error_response(transport, response_queue_id, uid, response_data)
self.mark_completed(uid)
elif status == LitAPIStatus.FINISH_STREAMING:
self.put_response(response_queues, response_queue_id, uid, response_data, status)
self.put_response(transport, response_queue_id, uid, response_data, status)
self.mark_completed(uid)
else:
self.put_response(response_queues, response_queue_id, uid, response_data, status)
self.put_response(transport, response_queue_id, uid, response_data, status)

except Exception as e:
logger.exception(f"Error in continuous batching loop: {e}")
# Handle any errors by sending error responses for all tracked requests
for uid, response_queue_id in self.response_queue_ids.items():
self.put_error_response(response_queues, response_queue_id, uid, e)
self.put_error_response(transport, response_queue_id, uid, e)
self.response_queue_ids.clear()
self.active_sequences.clear()

Expand Down
15 changes: 4 additions & 11 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
# limitations under the License.
import logging
from queue import Queue
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import _BaseLoop
from litserve.loops.simple_loops import BatchedLoop, SingleLoop
from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import WorkerSetupStatus
from litserve.zmq_queue import Producer

logger = logging.getLogger(__name__)

Expand All @@ -45,15 +45,13 @@ def inference_worker(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
loop: Union[str, _BaseLoop],
use_zmq: bool,
zmq_addr: Optional[str],
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
try:
Expand All @@ -76,18 +74,13 @@ def inference_worker(
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

if use_zmq:
producer = Producer(address=zmq_addr)
producer.wait_for_subscribers(timeout=5)
loop.producer = producer

loop(
lit_api,
lit_spec,
device,
worker_id,
request_queue,
response_queues,
transport,
max_batch_size,
batch_timeout,
stream,
Expand Down
31 changes: 16 additions & 15 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import logging
import time
from queue import Empty, Queue
from typing import Dict, List, Optional
from typing import Dict, Optional

from fastapi import HTTPException

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import DefaultLoop, _inject_context, collate_requests
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import LitAPIStatus, PickleableHTTPException

logger = logging.getLogger(__name__)
Expand All @@ -33,7 +34,7 @@ def run_single_loop(
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
callback_runner: CallbackRunner,
):
while True:
Expand All @@ -51,7 +52,7 @@ def run_single_loop(
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
self.put_response(
response_queues=response_queues,
transport=transport,
response_queue_id=response_queue_id,
uid=uid,
response_data=(HTTPException(504, "Request timed out")),
Expand Down Expand Up @@ -87,7 +88,7 @@ def run_single_loop(
)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
self.put_response(
response_queues=response_queues,
transport=transport,
response_queue_id=response_queue_id,
uid=uid,
response_data=y_enc,
Expand All @@ -96,7 +97,7 @@ def run_single_loop(

except HTTPException as e:
self.put_response(
response_queues=response_queues,
transport=transport,
response_queue_id=response_queue_id,
uid=uid,
response_data=PickleableHTTPException.from_exception(e),
Expand All @@ -110,7 +111,7 @@ def run_single_loop(
uid,
)
self.put_error_response(
response_queues=response_queues,
transport=transport,
response_queue_id=response_queue_id,
uid=uid,
error=e,
Expand All @@ -123,14 +124,14 @@ def __call__(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
self.run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
self.run_single_loop(lit_api, lit_spec, request_queue, transport, callback_runner)


class BatchedLoop(DefaultLoop):
Expand All @@ -139,7 +140,7 @@ def run_batched_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
Expand All @@ -159,7 +160,7 @@ def run_batched_loop(
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
self.put_response(
response_queues, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
transport, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
)

if not batches:
Expand Down Expand Up @@ -207,12 +208,12 @@ def run_batched_loop(
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

for response_queue_id, uid, y_enc in y_enc_list:
self.put_response(response_queues, response_queue_id, uid, y_enc, LitAPIStatus.OK)
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)

except HTTPException as e:
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(
response_queues,
transport,
response_queue_id,
uid,
PickleableHTTPException.from_exception(e),
Expand All @@ -225,7 +226,7 @@ def run_batched_loop(
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_error_response(response_queues, response_queue_id, uid, e)
self.put_error_response(transport, response_queue_id, uid, e)

def __call__(
self,
Expand All @@ -234,7 +235,7 @@ def __call__(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
Expand All @@ -245,7 +246,7 @@ def __call__(
lit_api,
lit_spec,
request_queue,
response_queues,
transport,
max_batch_size,
batch_timeout,
callback_runner,
Expand Down
Loading
Loading