Skip to content
19 changes: 7 additions & 12 deletions src/litserve/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,23 @@
class _Connector:
def __init__(self, accelerator: str = "auto", devices: Union[List[int], int, str] = "auto"):
accelerator = self._sanitize_accelerator(accelerator)
if accelerator == "cpu":
self._accelerator = "cpu"
elif accelerator == "cuda":
self._accelerator = "cuda"
elif accelerator == "mps":
self._accelerator = "mps"

if accelerator in ("cpu", "cuda", "mps"):
self._accelerator = accelerator
elif accelerator == "auto":
self._accelerator = self._choose_auto_accelerator()
elif accelerator == "gpu":
self._accelerator = self._choose_gpu_accelerator_backend()

if devices == "auto":
self._devices = self._auto_device_count(self._accelerator)
self._devices = self._accelerator_device_count()
else:
self._devices = devices

self.check_devices_and_accelerators()

def check_devices_and_accelerators(self):
"""Check if the devices are in a valid format and raise an error if they are not."""
if self._accelerator in ["cuda", "mps"]:
if self._accelerator in ("cuda", "mps"):
if not isinstance(self._devices, int) and not (
isinstance(self._devices, list) and all(isinstance(device, int) for device in self._devices)
):
Expand All @@ -68,7 +63,7 @@ def _sanitize_accelerator(accelerator: Optional[str]):
accelerator = accelerator.lower()

if accelerator not in ["auto", "cpu", "mps", "cuda", "gpu", None]:
raise ValueError("accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'")
raise ValueError(f"accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'. Found: {accelerator}")

if accelerator is None:
return "auto"
Expand All @@ -80,8 +75,8 @@ def _choose_auto_accelerator(self):
return gpu_backend
return "cpu"

def _auto_device_count(self, accelerator) -> int:
if accelerator == "cuda":
def _accelerator_device_count(self) -> int:
if self._accelerator == "cuda":
return check_cuda_with_nvidia_smi()
return 1

Expand Down
17 changes: 8 additions & 9 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@


def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
    """Return a new loop instance matching the streaming and batching settings.

    Args:
        stream: Whether a streaming loop is wanted.
        max_batch_size: A value greater than 1 selects a batched loop.

    Returns:
        ``BatchedStreamingLoop`` when streaming with ``max_batch_size > 1``,
        ``StreamingLoop`` when streaming otherwise, ``BatchedLoop`` when not
        streaming with ``max_batch_size > 1``, and ``SingleLoop`` in the
        remaining case.
    """
    batched = max_batch_size > 1
    # Guard clause instead of `else` after `return`, so no RET505 suppression
    # is needed and only the selected loop class is instantiated.
    if stream:
        return BatchedStreamingLoop() if batched else StreamingLoop()
    return BatchedLoop() if batched else SingleLoop()


def inference_worker(
Expand Down
4 changes: 2 additions & 2 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ async def info(request: Request) -> Response:
}
)

async def predict(request: self.request_type) -> self.response_type:
async def predict(request: self.request_type) -> self.response_type: # should be Any
self._callback_runner.trigger_event(
EventTypes.ON_REQUEST,
active_requests=self.active_requests,
Expand Down Expand Up @@ -477,7 +477,7 @@ async def predict(request: self.request_type) -> self.response_type:
self._callback_runner.trigger_event(EventTypes.ON_RESPONSE, litserver=self)
return response

async def stream_predict(request: self.request_type) -> self.response_type:
async def stream_predict(request: self.request_type) -> self.response_type: # should be Any
self._callback_runner.trigger_event(
EventTypes.ON_REQUEST,
active_requests=self.active_requests,
Expand Down
7 changes: 5 additions & 2 deletions src/litserve/specs/openai_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import time
import uuid
from typing import List, Literal, Optional, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Union

from fastapi import HTTPException, Request, Response, status
from fastapi import status as status_code
Expand All @@ -27,6 +27,9 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from litserve import LitServer


class EmbeddingRequest(BaseModel):
input: Union[str, List[str], List[int], List[List[int]]]
Expand Down Expand Up @@ -93,7 +96,7 @@ def __init__(self):
self.add_endpoint("/v1/embeddings", self.embeddings, ["POST"])
self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"])

def setup(self, server: "LitServer"): # noqa: F821
def setup(self, server: "LitServer"):
from litserve import LitAPI

super().setup(server)
Expand Down
2 changes: 0 additions & 2 deletions src/litserve/transport/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Optional

# TODO: raise NotImplementedError for all methods


class MessageTransport(ABC):
@abstractmethod
Expand Down
Loading