Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-parity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install LitServe
run: |
pip --version
pip install . torchvision jsonargparse uvloop tenacity -U -q -r _requirements/test.txt -U -q
pip install . torchvision jsonargparse tenacity -U -q -r _requirements/test.txt -U -q
pip list

- name: Parity test
Expand Down
1 change: 0 additions & 1 deletion _requirements/perf.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
uvloop
tenacity
jsonargparse
19 changes: 7 additions & 12 deletions src/litserve/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,23 @@
class _Connector:
def __init__(self, accelerator: str = "auto", devices: Union[List[int], int, str] = "auto"):
accelerator = self._sanitize_accelerator(accelerator)
if accelerator == "cpu":
self._accelerator = "cpu"
elif accelerator == "cuda":
self._accelerator = "cuda"
elif accelerator == "mps":
self._accelerator = "mps"

if accelerator in ("cpu", "cuda", "mps"):
self._accelerator = accelerator
elif accelerator == "auto":
self._accelerator = self._choose_auto_accelerator()
elif accelerator == "gpu":
self._accelerator = self._choose_gpu_accelerator_backend()

if devices == "auto":
self._devices = self._auto_device_count(self._accelerator)
self._devices = self._accelerator_device_count()
else:
self._devices = devices

self.check_devices_and_accelerators()

def check_devices_and_accelerators(self):
"""Check if the devices are in a valid format and raise an error if they are not."""
if self._accelerator in ["cuda", "mps"]:
if self._accelerator in ("cuda", "mps"):
if not isinstance(self._devices, int) and not (
isinstance(self._devices, list) and all(isinstance(device, int) for device in self._devices)
):
Expand All @@ -68,7 +63,7 @@ def _sanitize_accelerator(accelerator: Optional[str]):
accelerator = accelerator.lower()

if accelerator not in ["auto", "cpu", "mps", "cuda", "gpu", None]:
raise ValueError("accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'")
raise ValueError(f"accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'. Found: {accelerator}")

if accelerator is None:
return "auto"
Expand All @@ -80,8 +75,8 @@ def _choose_auto_accelerator(self):
return gpu_backend
return "cpu"

def _auto_device_count(self, accelerator) -> int:
if accelerator == "cuda":
def _accelerator_device_count(self) -> int:
if self._accelerator == "cuda":
return check_cuda_with_nvidia_smi()
return 1

Expand Down
Loading