diff --git a/changes/8388.enhance.md b/changes/8388.enhance.md
new file mode 100644
index 00000000000..72265515e79
--- /dev/null
+++ b/changes/8388.enhance.md
@@ -0,0 +1 @@
+Enable the ANN001 linter rule and add type annotations to function arguments
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 95db0a4e3fd..0a9a3cc469a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -171,13 +171,8 @@ ignore = [
     "A001",    # builtin-variable-shadowing - mostly id/filter/input in GraphQL (75 violations)
     "A002",    # builtin-argument-shadowing - mostly id/filter/input in GraphQL (297 violations)
     # ANN (annotations) - defer for gradual migration, mypy handles type checking
-    "ANN001",  # missing-type-function-argument - 1336 violations, defer for gradual migration
     "ANN002",  # missing-type-args - *args type annotation, low practical value
     "ANN003",  # missing-type-kwargs - **kwargs type annotation, low practical value
-    # "ANN201", # missing-return-type-undocumented-public-function - ENABLING: analyzing patterns
-    # "ANN202", # missing-return-type-private-function - ENABLED: fixing violations incrementally
-    # "ANN205", # missing-return-type-static-method - ENABLED: fixing violations incrementally
-    # "ANN206", # missing-return-type-class-method - ENABLED: fixing violations incrementally
     "ANN401",  # dynamically-typed-expression - Any needed for GraphQL resolvers, dynamic APIs
 ]
diff --git a/src/ai/backend/accelerator/cuda_open/nvidia.py b/src/ai/backend/accelerator/cuda_open/nvidia.py
index cb1fe8944a7..1b6d2e1ffc6 100644
--- a/src/ai/backend/accelerator/cuda_open/nvidia.py
+++ b/src/ai/backend/accelerator/cuda_open/nvidia.py
@@ -548,7 +548,7 @@ def _ensure_lib(cls) -> None:
             raise ImportError(f"Could not load the {cls.name} library!")
 
     @classmethod
-    def invoke(cls, func_name, *args, check_rc=True) -> int:
+    def invoke(cls, func_name: str, *args: Any, check_rc: bool = True) -> int:
         try:
             cls._ensure_lib()
         except ImportError:
diff --git a/src/ai/backend/accelerator/ipu/setup.py b/src/ai/backend/accelerator/ipu/setup.py
index 029134fa21e..ec7c03661f7 100644
--- a/src/ai/backend/accelerator/ipu/setup.py
+++ b/src/ai/backend/accelerator/ipu/setup.py
@@ -1,5 +1,6 @@
 import sys
 from pathlib import Path
+from zipfile import ZipInfo
 
 from Cython.Build import cythonize
 from setuptools import setup
@@ -15,7 +16,9 @@
 )
 
 
-def _filtered_writestr(self, zinfo_or_arcname, bytes, compress_type=None) -> None:
+def _filtered_writestr(
+    self: WheelFile, zinfo_or_arcname: ZipInfo | str, bytes: bytes, compress_type: int | None = None
+) -> None:
     global exclude_source_files
     if exclude_source_files:
         if isinstance(zinfo_or_arcname, str):
diff --git a/src/ai/backend/accelerator/rebellions/atom/pci.py b/src/ai/backend/accelerator/rebellions/atom/pci.py
index f078e07ea3f..ea6ab37deec 100644
--- a/src/ai/backend/accelerator/rebellions/atom/pci.py
+++ b/src/ai/backend/accelerator/rebellions/atom/pci.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 
-async def read_sysfs(path, attr) -> str:
+async def read_sysfs(path: Path, attr: str) -> str:
     def _blocking() -> str:
         return (path / attr).read_text().strip()
diff --git a/src/ai/backend/accelerator/rocm/exception.py b/src/ai/backend/accelerator/rocm/exception.py
index d6d0c91992d..5962f0c5878 100644
--- a/src/ai/backend/accelerator/rocm/exception.py
+++ b/src/ai/backend/accelerator/rocm/exception.py
@@ -1,20 +1,20 @@
 class NoRocmDeviceError(Exception):
-    def __init__(self, message) -> None:
+    def __init__(self, message: str) -> None:
         self.message = message
 
 
 class GenericRocmError(Exception):
-    def __init__(self, message) -> None:
+    def __init__(self, message: str) -> None:
         self.message = message
 
 
 class RocmUtilFetchError(Exception):
-    def __init__(self, message) -> None:
+    def __init__(self, message: str) -> None:
         self.message = message
 
 
 class RocmMemFetchError(Exception):
-    def __init__(self, message) -> None:
+    def __init__(self, message: str) -> None:
         self.message = message
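The exception-class hunks above are the canonical shape of an ANN001 fix. For reference, a minimal self-contained sketch (the class names are illustrative, not from this patch) of what Ruff reports and how the annotation silences it:

```python
# Before: ruff reports "ANN001 Missing type annotation for function argument `message`".
class ExampleError(Exception):
    def __init__(self, message) -> None:  # noqa: ANN001
        self.message = message


# After: annotating the argument satisfies ANN001; the return type was
# already covered by the previously enabled ANN2xx rules.
class ExampleTypedError(Exception):
    def __init__(self, message: str) -> None:
        self.message = message
```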
diff --git a/src/ai/backend/accelerator/tenstorrent/n300/pci.py b/src/ai/backend/accelerator/tenstorrent/n300/pci.py
index f078e07ea3f..ea6ab37deec 100644
--- a/src/ai/backend/accelerator/tenstorrent/n300/pci.py
+++ b/src/ai/backend/accelerator/tenstorrent/n300/pci.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 
-async def read_sysfs(path, attr) -> str:
+async def read_sysfs(path: Path, attr: str) -> str:
     def _blocking() -> str:
         return (path / attr).read_text().strip()
diff --git a/src/ai/backend/accelerator/tpu/plugin.py b/src/ai/backend/accelerator/tpu/plugin.py
index fa912dbed69..2b56b869c09 100644
--- a/src/ai/backend/accelerator/tpu/plugin.py
+++ b/src/ai/backend/accelerator/tpu/plugin.py
@@ -17,6 +17,7 @@
     DiscretePropertyAllocMap,
 )
 from ai.backend.agent.stats import ContainerMeasurement, NodeMeasurement, StatContext
+from ai.backend.agent.types import Container
 from ai.backend.common.logging import BraceStyleAdapter
 from ai.backend.common.types import (
     BinarySize,
@@ -194,7 +195,9 @@ async def get_attached_devices(
         })
         return attached_devices
 
-    async def restore_from_container(self, container, alloc_map) -> None:
+    async def restore_from_container(
+        self, container: Container, alloc_map: AbstractAllocMap
+    ) -> None:
         # TODO: implement!
         pass
diff --git a/src/ai/backend/accelerator/tpu/tpu.py b/src/ai/backend/accelerator/tpu/tpu.py
index 3eb710002b7..69dba539725 100644
--- a/src/ai/backend/accelerator/tpu/tpu.py
+++ b/src/ai/backend/accelerator/tpu/tpu.py
@@ -18,7 +18,7 @@ class libtpu:
     zone: ClassVar[Optional[str]] = None
 
     @classmethod
-    async def _run_ctpu(cls, cmd) -> str:
+    async def _run_ctpu(cls, cmd: list[str]) -> str:
         if not cls.zone:
             try:
                 proc = await subprocess.create_subprocess_exec(
diff --git a/src/ai/backend/account_manager/api/utils.py b/src/ai/backend/account_manager/api/utils.py
index 28248403c32..d7720047fef 100644
--- a/src/ai/backend/account_manager/api/utils.py
+++ b/src/ai/backend/account_manager/api/utils.py
@@ -18,7 +18,7 @@ def auth_required(handler: Handler) -> Handler:
     @functools.wraps(handler)
-    async def wrapped(request, *args, **kwargs) -> web.StreamResponse:
+    async def wrapped(request: web.Request, *args, **kwargs) -> web.StreamResponse:
         if request.get("is_authorized", False):
             return await handler(request, *args, **kwargs)
         raise AuthorizationFailed("Unauthorized access")
@@ -28,15 +28,15 @@ async def wrapped(request, *args, **kwargs) -> web.StreamResponse:
     return wrapped
 
 
-def set_handler_attr(func, key, value) -> None:
+def set_handler_attr(func: Callable, key: str, value: Any) -> None:
     attrs = getattr(func, "_backend_attrs", None)
     if attrs is None:
         attrs = {}
     attrs[key] = value
-    func._backend_attrs = attrs
+    func._backend_attrs = attrs  # type: ignore[attr-defined]
 
 
-def get_handler_attr(request, key, default=None) -> Any:
+def get_handler_attr(request: web.Request, key: str, default: Any = None) -> Any:
     # When used in the aiohttp server-side codes, we should use
     # request.match_info.hanlder instead of handler passed to the middleware
     # functions because aiohttp wraps this original handler with functools.partial
@@ -116,7 +116,7 @@ def ensure_stream_response_type[TAnyResponse: web.StreamResponse](
 
 def pydantic_api_response_handler(
     handler: THandlerFuncWithoutParam,
-    is_deprecated=False,
+    is_deprecated: bool = False,
 ) -> Handler:
     """
     Only for API handlers which does not require request body.
@@ -141,7 +141,7 @@ def pydantic_api_handler[TParamModel: BaseModel, TQueryModel: BaseModel](
     checker: type[TParamModel],
     loads: Callable[[str], Any] | None = None,
     query_param_checker: type[TQueryModel] | None = None,
-    is_deprecated=False,
+    is_deprecated: bool = False,
 ) -> Callable[[THandlerFuncWithParam], Handler]:
     def wrap(
         handler: THandlerFuncWithParam,
diff --git a/src/ai/backend/account_manager/models/base.py b/src/ai/backend/account_manager/models/base.py
index 24aa97d3790..cb01e8ec685 100644
--- a/src/ai/backend/account_manager/models/base.py
+++ b/src/ai/backend/account_manager/models/base.py
@@ -55,13 +55,13 @@ class GUID[UUID_SubType: uuid.UUID](TypeDecorator):
     uuid_subtype_func: ClassVar[Callable[[Any], Any]] = lambda v: v
     cache_ok = True
 
-    def load_dialect_impl(self, dialect) -> TypeDecorator:
+    def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator:
         if dialect.name == "postgresql":
-            return dialect.type_descriptor(UUID())
-        return dialect.type_descriptor(CHAR(16))
+            return cast(TypeDecorator, dialect.type_descriptor(UUID()))
+        return cast(TypeDecorator, dialect.type_descriptor(CHAR(16)))
 
     def process_bind_param(
-        self, value: UUID_SubType | uuid.UUID | None, dialect
+        self, value: UUID_SubType | uuid.UUID | None, dialect: Dialect
     ) -> str | bytes | None:
         # NOTE: EndpointId, SessionId, KernelId are *not* actual types defined as classes,
         # but a "virtual" type that is an identity function at runtime.
@@ -82,7 +82,7 @@ def process_bind_param(
             case _:
                 return uuid.UUID(value).bytes
 
-    def process_result_value(self, value: Any, dialect) -> UUID_SubType | None:
+    def process_result_value(self, value: Any, dialect: Dialect) -> UUID_SubType | None:
         if value is None:
             return value
         cls = type(self)
@@ -134,9 +134,9 @@ def python_type(self) -> type[T_StrEnum]:
 class PasswordColumn(TypeDecorator):
     impl = VARCHAR
 
-    def process_bind_param(self, value, dialect) -> str:
+    def process_bind_param(self, value: Any, dialect: Dialect) -> str:
         return hash_password(value)
 
 
-def IDColumn(name="id") -> sa.Column:
+def IDColumn(name: str = "id") -> sa.Column:
     return sa.Column(name, GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
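The `models/base.py` hunks follow SQLAlchemy's `TypeDecorator` protocol, where `dialect` was the last untyped argument. A condensed, runnable sketch of the same pattern under SQLAlchemy 2.x (the column class here is a stand-in, not the project's `GUID`):

```python
import uuid
from typing import Any, Optional

import sqlalchemy as sa
from sqlalchemy.types import CHAR, TypeDecorator


class HexUUID(TypeDecorator):
    """Illustrative column type storing UUIDs as 32-char hex strings."""

    impl = CHAR
    cache_ok = True

    def process_bind_param(self, value: Optional[uuid.UUID], dialect: sa.Dialect) -> Optional[str]:
        # `dialect: sa.Dialect` is exactly the annotation ANN001 asks for here.
        return None if value is None else value.hex

    def process_result_value(self, value: Any, dialect: sa.Dialect) -> Optional[uuid.UUID]:
        return None if value is None else uuid.UUID(value)
```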
""" @@ -502,7 +502,7 @@ async def prepare_container( self, resource_spec: KernelResourceSpec, environ: Mapping[str, str], - service_ports, + service_ports: list[ServicePort], cluster_info: ClusterInfo, ) -> KernelObjectType: raise NotImplementedError @@ -512,8 +512,8 @@ async def start_container( self, kernel_obj: AbstractKernel, cmdargs: list[str], - resource_opts, - preopen_ports, + resource_opts: Optional[Mapping[str, Any]], + preopen_ports: list[int], cluster_info: ClusterInfo, ) -> Mapping[str, Any]: raise NotImplementedError @@ -582,9 +582,9 @@ async def mount_krunner( environ: MutableMapping[str, str], ) -> None: def _mount( - type, - src, - dst, + type: MountTypes, + src: str | Path, + dst: str | Path, ) -> None: resource_spec.mounts.append( self.get_runner_mount( @@ -3687,7 +3687,7 @@ async def shutdown_service(self, kernel_id: KernelId, service: str) -> None: async def commit( self, - reporter, + reporter: Any, kernel_id: KernelId, subdir: str, *, @@ -3719,7 +3719,7 @@ async def list_files(self, kernel_id: KernelId, path: str) -> dict[str, Any]: async def ping_kernel(self, kernel_id: KernelId) -> dict[str, float] | None: return await self.kernel_registry[kernel_id].ping() - async def save_last_registry(self, force=False) -> None: + async def save_last_registry(self, force: bool = False) -> None: await self._write_kernel_registry_to_recovery( self.kernel_registry, KernelRegistrySaveMetadata(force), diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 292c0a60f23..86095c53b15 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -36,7 +36,7 @@ import zmq import zmq.asyncio from aiodocker.docker import Docker, DockerContainer -from aiodocker.exceptions import DockerError +from aiodocker.exceptions import DockerContainerError, DockerError from aiodocker.types import PortInfo from aiomonitor.task import preserve_termination_log from aiotools import TaskGroup @@ -192,7 +192,7 @@ ) -async def get_extra_volumes(docker, lang) -> list[VolumeInfo]: +async def get_extra_volumes(docker: Docker, lang: str) -> list[VolumeInfo]: avail_volumes = (await docker.volumes.list())["Volumes"] if not avail_volumes: return [] @@ -265,14 +265,14 @@ async def _clean_scratch( pass -def _DockerError_reduce(self) -> tuple[type, tuple[Any, ...]]: +def _DockerError_reduce(self: DockerError) -> tuple[type, tuple[Any, ...]]: return ( type(self), (self.status, {"message": self.message}, *self.args), ) -def _DockerContainerError_reduce(self) -> tuple[type, tuple[Any, ...]]: +def _DockerContainerError_reduce(self: DockerContainerError) -> tuple[type, tuple[Any, ...]]: return ( type(self), (self.status, {"message": self.message}, self.container_id, *self.args), @@ -644,7 +644,7 @@ async def get_intrinsic_mounts(self) -> Sequence[Mount]: return mounts @override - def resolve_krunner_filepath(self, filename) -> Path: + def resolve_krunner_filepath(self, filename: str) -> Path: return Path( pkg_resources.resource_filename( "ai.backend.runner", @@ -1014,8 +1014,8 @@ async def start_container( self, kernel_obj: AbstractKernel, cmdargs: list[str], - resource_opts, - preopen_ports, + resource_opts: Optional[Mapping[str, Any]], + preopen_ports: list[int], cluster_info: ClusterInfo, ) -> Mapping[str, Any]: loop = current_loop() diff --git a/src/ai/backend/agent/docker/files.py b/src/ai/backend/agent/docker/files.py index 36848e332c4..41ef94381e1 100644 --- a/src/ai/backend/agent/docker/files.py +++ b/src/ai/backend/agent/docker/files.py @@ 
diff --git a/src/ai/backend/agent/docker/files.py b/src/ai/backend/agent/docker/files.py
index 36848e332c4..41ef94381e1 100644
--- a/src/ai/backend/agent/docker/files.py
+++ b/src/ai/backend/agent/docker/files.py
@@ -17,7 +17,7 @@
     log.info("Automatic ~/.output file S3 uploads is disabled.")
 
 
-def relpath(path, base) -> Path:
+def relpath(path: Path | str, base: Path | str) -> Path:
    return Path(path).resolve().relative_to(Path(base).resolve())
 
 
@@ -52,7 +52,7 @@ def scandir(root: Path, allowed_max_size: int) -> dict[Path, float]:
     return file_stats
 
 
-def diff_file_stats(fs1, fs2) -> set[Path]:
+def diff_file_stats(fs1: dict[Path, float], fs2: dict[Path, float]) -> set[Path]:
     k2 = set(fs2.keys())
     k1 = set(fs1.keys())
     new_files = k2 - k1
diff --git a/src/ai/backend/agent/docker/intrinsic.py b/src/ai/backend/agent/docker/intrinsic.py
index 169ee7287b8..ca48877793c 100644
--- a/src/ai/backend/agent/docker/intrinsic.py
+++ b/src/ai/backend/agent/docker/intrinsic.py
@@ -403,9 +403,9 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: Docker,
-        device_alloc,
+        device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
     ) -> Mapping[str, Any]:
-        cores = [*map(int, device_alloc["cpu"].keys())]
+        cores = [*map(int, device_alloc[SlotName("cpu")].keys())]
         sorted_core_ids = [*map(str, sorted(cores))]
         return {
             "HostConfig": {
@@ -933,9 +933,9 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: Docker,
-        device_alloc,
+        device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
     ) -> Mapping[str, Any]:
-        memory = sum(device_alloc["mem"].values())
+        memory = sum(device_alloc[SlotName("mem")].values())
         return {
             "HostConfig": {
                 "MemorySwap": int(memory),  # prevent using swap!
@@ -1046,7 +1046,7 @@ async def get_capabilities(self) -> set[ContainerNetworkCapability]:
         return {ContainerNetworkCapability.GLOBAL}
 
     async def join_network(
-        self, kernel_config: KernelCreationConfig, cluster_info: ClusterInfo, **kwargs
+        self, kernel_config: KernelCreationConfig, cluster_info: ClusterInfo, **kwargs: Any
     ) -> dict[str, Any]:
         if _cluster_ssh_port_mapping := cluster_info.get("cluster_ssh_port_mapping"):
             return {
@@ -1068,7 +1068,7 @@ async def leave_network(self, kernel: DockerKernel) -> None:
         pass
 
     async def prepare_port_forward(
-        self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs
+        self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs: Any
     ) -> None:
         host_ports = [p[0] for p in ports]
         scratch_dir = (
@@ -1091,7 +1091,7 @@ async def prepare_port_forward(
     )
 
     async def expose_ports(
-        self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs
+        self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs: Any
     ) -> ContainerNetworkInfo:
         host_ports = [p[0] for p in ports]
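`device_alloc["cpu"]` stops type-checking once the mapping key becomes `SlotName`, because `NewType("SlotName", str)` is assignable to `str` but a plain `str` is not assignable to `SlotName`. A small sketch with stand-in types (the real `SlotName`/`DeviceId` live in `ai.backend.common.types`):

```python
from collections.abc import Mapping
from decimal import Decimal
from typing import NewType

# Stand-ins assumed to mirror ai.backend.common.types.
SlotName = NewType("SlotName", str)
DeviceId = NewType("DeviceId", str)


def allocated_cores(device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]]) -> list[int]:
    # device_alloc["cpu"] would be rejected by mypy; wrapping the literal in
    # SlotName() produces the expected key type at zero runtime cost.
    return sorted(int(device_id) for device_id in device_alloc[SlotName("cpu")])


print(allocated_cores({SlotName("cpu"): {DeviceId("0"): Decimal(1), DeviceId("3"): Decimal(1)}}))
```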
diff --git a/src/ai/backend/agent/docker/kernel.py b/src/ai/backend/agent/docker/kernel.py
index 88397978b82..ece7d6b11dd 100644
--- a/src/ai/backend/agent/docker/kernel.py
+++ b/src/ai/backend/agent/docker/kernel.py
@@ -11,7 +11,7 @@
 import shutil
 import subprocess
 import textwrap
-from collections.abc import Mapping
+from collections.abc import Mapping, MutableMapping
 from pathlib import Path, PurePosixPath
 from typing import Any, Final, Optional, cast, override
 
@@ -34,7 +34,7 @@
 from ai.backend.common.dto.agent.response import CodeCompletionResp
 from ai.backend.common.events.dispatcher import EventProducer
 from ai.backend.common.lock import FileLock
-from ai.backend.common.types import CommitStatus, KernelId, Sentinel
+from ai.backend.common.types import CommitStatus, KernelId, Sentinel, SessionId
 from ai.backend.logging import BraceStyleAdapter
 from ai.backend.plugin.entrypoint import scan_entrypoints
 
@@ -82,7 +82,7 @@ async def close(self) -> None:
     def __getstate__(self) -> Mapping[str, Any]:
         return super().__getstate__()
 
-    def __setstate__(self, props) -> None:
+    def __setstate__(self, props: MutableMapping[str, Any]) -> None:
         if "network_driver" not in props:
             props["network_driver"] = "bridge"
         super().__setstate__(props)
@@ -186,8 +186,8 @@ async def check_duplicate_commit(self, kernel_id: KernelId, subdir: str) -> Comm
     @override
     async def commit(
         self,
-        kernel_id,
-        subdir,
+        kernel_id: KernelId,
+        subdir: str,
         *,
         canonical: str | None = None,
         filename: str | None = None,
@@ -452,15 +452,15 @@ class DockerCodeRunner(AbstractCodeRunner):
 
     def __init__(
         self,
-        kernel_id,
-        session_id,
-        event_producer,
+        kernel_id: KernelId,
+        session_id: SessionId,
+        event_producer: EventProducer,
         *,
-        kernel_host,
-        repl_in_port,
-        repl_out_port,
-        exec_timeout=0,
-        client_features=None,
+        kernel_host: str,
+        repl_in_port: int,
+        repl_out_port: int,
+        exec_timeout: int = 0,
+        client_features: frozenset[str] | None = None,
     ) -> None:
         super().__init__(
             kernel_id,
diff --git a/src/ai/backend/agent/docker/resources.py b/src/ai/backend/agent/docker/resources.py
index 9329bbad5dc..f5a3d285bd7 100644
--- a/src/ai/backend/agent/docker/resources.py
+++ b/src/ai/backend/agent/docker/resources.py
@@ -83,7 +83,9 @@ async def scan_available_resources(
     return slots
 
 
-async def get_resource_spec_from_container(container_info) -> Optional[KernelResourceSpec]:
+async def get_resource_spec_from_container(
+    container_info: Mapping[str, Any],
+) -> Optional[KernelResourceSpec]:
     for mount in container_info["HostConfig"]["Mounts"]:
         if mount["Target"] == "/home/config":
             async with aiofiles.open(Path(mount["Source"]) / "resource.txt") as f:
diff --git a/src/ai/backend/agent/dummy/agent.py b/src/ai/backend/agent/dummy/agent.py
index 8997950326d..127828bfd66 100644
--- a/src/ai/backend/agent/dummy/agent.py
+++ b/src/ai/backend/agent/dummy/agent.py
@@ -166,7 +166,7 @@ async def generate_accelerator_mounts(
         return []
 
     @override
-    def resolve_krunner_filepath(self, filename) -> Path:
+    def resolve_krunner_filepath(self, filename: str) -> Path:
         return Path()
 
     @override
@@ -185,7 +185,7 @@ async def prepare_container(
         self,
         resource_spec: KernelResourceSpec,
         environ: Mapping[str, str],
-        service_ports,
+        service_ports: list[ServicePort],
         cluster_info: ClusterInfo,
     ) -> DummyKernel:
         delay = self.creation_ctx_config["delay"]["spawn"]
@@ -208,8 +208,8 @@ async def start_container(
         self,
         kernel_obj: AbstractKernel,
         cmdargs: list[str],
-        resource_opts,
-        preopen_ports,
+        resource_opts: Optional[Mapping[str, Any]],
+        preopen_ports: Sequence[int],
         cluster_info: ClusterInfo,
     ) -> Mapping[str, Any]:
         container_bind_host = self.local_config.container.bind_host
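`__setstate__` now takes `MutableMapping[str, Any]`, matching what pickle actually passes and allowing the in-place default backfill seen in `DockerKernel`. A runnable sketch of that backward-compatibility pattern (the class is illustrative):

```python
import pickle
from collections.abc import Mapping, MutableMapping
from typing import Any


class Kernelish:
    def __init__(self, name: str, network_driver: str = "bridge") -> None:
        self.name = name
        self.network_driver = network_driver

    def __getstate__(self) -> Mapping[str, Any]:
        return dict(self.__dict__)

    def __setstate__(self, props: MutableMapping[str, Any]) -> None:
        # Old pickles may predate `network_driver`; mutating `props` before
        # applying it is why the parameter is a MutableMapping, not a Mapping.
        props.setdefault("network_driver", "bridge")
        self.__dict__.update(props)


restored = pickle.loads(pickle.dumps(Kernelish("k1")))
print(restored.network_driver)  # -> bridge
```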
diff --git a/src/ai/backend/agent/dummy/intrinsic.py b/src/ai/backend/agent/dummy/intrinsic.py
index 368184ebd6a..5730695b081 100644
--- a/src/ai/backend/agent/dummy/intrinsic.py
+++ b/src/ai/backend/agent/dummy/intrinsic.py
@@ -11,6 +11,7 @@
     AbstractAllocMap,
     AbstractComputeDevice,
     AbstractComputePlugin,
+    DeviceAllocation,
     DeviceSlotInfo,
     DiscretePropertyAllocMap,
 )
@@ -134,9 +135,9 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: aiodocker.docker.Docker,
-        device_alloc,
+        device_alloc: DeviceAllocation,
     ) -> Mapping[str, Any]:
-        cores = [*map(int, device_alloc["cpu"].keys())]
+        cores = [*map(int, device_alloc[SlotName("cpu")].keys())]
         sorted_core_ids = [*map(str, sorted(cores))]
         return {
             "HostConfig": {
@@ -286,7 +287,7 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: aiodocker.docker.Docker,
-        device_alloc,
+        device_alloc: DeviceAllocation,
     ) -> Mapping[str, Any]:
         return {}
diff --git a/src/ai/backend/agent/dummy/kernel.py b/src/ai/backend/agent/dummy/kernel.py
index 97bb87fdf92..5c927551fe2 100644
--- a/src/ai/backend/agent/dummy/kernel.py
+++ b/src/ai/backend/agent/dummy/kernel.py
@@ -3,8 +3,8 @@
 import asyncio
 import os
 from collections import OrderedDict
-from collections.abc import Mapping, Sequence
-from typing import Any, override
+from collections.abc import Mapping, MutableMapping, Sequence
+from typing import Any, Optional, override
 
 from ai.backend.agent.kernel import AbstractCodeRunner, AbstractKernel, NextResult, ResultRecord
 from ai.backend.agent.resources import KernelResourceSpec
@@ -12,7 +12,7 @@
 from ai.backend.common.docker import ImageRef
 from ai.backend.common.dto.agent.response import CodeCompletionResp, CodeCompletionResult
 from ai.backend.common.events.dispatcher import EventProducer
-from ai.backend.common.types import CommitStatus
+from ai.backend.common.types import CommitStatus, KernelId, SessionId
 
 
 class DummyKernel(AbstractKernel):
@@ -88,7 +88,7 @@ async def check_status(self) -> dict[str, Any]:
         return {}
 
     @override
-    async def get_completions(self, text, opts) -> CodeCompletionResp:
+    async def get_completions(self, text: str, opts: Mapping[str, Any]) -> CodeCompletionResp:
         delay = self.dummy_kernel_cfg["delay"]["get-completions"]
         await asyncio.sleep(delay)
         return CodeCompletionResp(result=CodeCompletionResult.success({"suggestions": []}))
@@ -106,7 +106,7 @@ async def interrupt_kernel(self) -> dict[str, Any]:
         return {}
 
     @override
-    async def start_service(self, service, opts) -> dict[str, Any]:
+    async def start_service(self, service: str, opts: Mapping[str, Any]) -> dict[str, Any]:
         delay = self.dummy_kernel_cfg["delay"]["start-service"]
         await asyncio.sleep(delay)
         return {}
@@ -118,12 +118,12 @@ async def start_model_service(self, model_service: Mapping[str, Any]) -> dict[st
         return {}
 
     @override
-    async def shutdown_service(self, service) -> None:
+    async def shutdown_service(self, service: str) -> None:
         delay = self.dummy_kernel_cfg["delay"]["shutdown-service"]
         await asyncio.sleep(delay)
 
     @override
-    async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus:
+    async def check_duplicate_commit(self, kernel_id: KernelId, subdir: str) -> CommitStatus:
         if self.is_commiting:
             return CommitStatus.ONGOING
         return CommitStatus.READY
@@ -131,8 +131,8 @@ async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus:
     @override
     async def commit(
         self,
-        kernel_id,
-        subdir,
+        kernel_id: KernelId,
+        subdir: str,
         *,
         canonical: str | None = None,
         filename: str | None = None,
@@ -189,15 +189,15 @@ class DummyCodeRunner(AbstractCodeRunner):
 
     def __init__(
         self,
-        kernel_id,
-        session_id,
-        event_producer,
+        kernel_id: KernelId,
+        session_id: SessionId,
+        event_producer: EventProducer,
         *,
-        kernel_host,
-        repl_in_port,
-        repl_out_port,
-        exec_timeout=0,
-        client_features=None,
+        kernel_host: str,
+        repl_in_port: int,
+        repl_out_port: int,
+        exec_timeout: float = 0,
+        client_features: Optional[frozenset[str]] = None,
     ) -> None:
         super().__init__(
             kernel_id,
@@ -230,15 +230,15 @@ class DummyFakeCodeRunner(AbstractCodeRunner):
 
     def __init__(
         self,
-        kernel_id,
-        session_id,
-        event_producer,
+        kernel_id: KernelId,
+        session_id: SessionId,
+        event_producer: EventProducer,
         *,
-        kernel_host,
-        repl_in_port,
-        repl_out_port,
-        exec_timeout=0,
-        client_features=None,
+        kernel_host: str,
+        repl_in_port: int,
+        repl_out_port: int,
+        exec_timeout: float = 0,
+        client_features: Optional[frozenset[str]] = None,
     ) -> None:
         self.zctx = None
         self.input_sock = None
@@ -266,7 +266,7 @@ def __init__(
     async def __ainit__(self) -> None:
         return
 
-    def __setstate__(self, props) -> None:
+    def __setstate__(self, props: MutableMapping[str, Any]) -> None:
         self.__dict__.update(props)
         self.zctx = None
         self.input_sock = None
@@ -302,7 +302,7 @@ async def ping_status(self) -> None:
         return None
 
     @override
-    async def feed_batch(self, opts) -> None:
+    async def feed_batch(self, opts: Mapping[str, Any]) -> None:
         return None
 
     @override
@@ -322,15 +322,17 @@ async def feed_and_get_status(self) -> None:
         return None
 
     @override
-    async def feed_and_get_completion(self, code_text, opts) -> CodeCompletionResult:
+    async def feed_and_get_completion(
+        self, code_text: str, opts: Mapping[str, Any]
+    ) -> CodeCompletionResult:
         return CodeCompletionResult.failure("not-implemented")
 
     @override
-    async def feed_start_model_service(self, model_info) -> dict[str, Any]:
+    async def feed_start_model_service(self, model_info: Mapping[str, Any]) -> dict[str, Any]:
         return {"status": "failed", "error": "not-implemented"}
 
     @override
-    async def feed_start_service(self, service_info) -> dict[str, Any]:
+    async def feed_start_service(self, service_info: Mapping[str, Any]) -> dict[str, Any]:
         return {"status": "failed", "error": "not-implemented"}
 
     @override
@@ -345,7 +347,7 @@ def aggregate_console(
         return
 
     @override
-    async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult:
+    async def get_next_result(self, api_ver: int = 2, flush_timeout: float = 2.0) -> NextResult:
         return {
             "runId": self.current_run_id,
             "status": "finished",
diff --git a/src/ai/backend/agent/exception.py b/src/ai/backend/agent/exception.py
index cc6ee77ce85..08fe1d67d8b 100644
--- a/src/ai/backend/agent/exception.py
+++ b/src/ai/backend/agent/exception.py
@@ -119,7 +119,7 @@ def __init__(self, container_id: str, message: str | None = None, *args, **kwarg
 
 
 class K8sError(Exception):
-    def __init__(self, message) -> None:
+    def __init__(self, message: str) -> None:
         super().__init__(message)
         self.message = message
diff --git a/src/ai/backend/agent/fs.py b/src/ai/backend/agent/fs.py
index 5452f7c404e..6dc59adb3be 100644
--- a/src/ai/backend/agent/fs.py
+++ b/src/ai/backend/agent/fs.py
@@ -1,8 +1,9 @@
 import asyncio
+from pathlib import Path
 from subprocess import CalledProcessError
 
 
-async def create_scratch_filesystem(scratch_dir, size) -> None:
+async def create_scratch_filesystem(scratch_dir: Path, size: int) -> None:
     """
     Create scratch folder size quota by using tmpfs filesystem.
 
@@ -30,7 +31,7 @@ async def create_scratch_filesystem(scratch_dir, size) -> None:
         raise CalledProcessError(proc.returncode, cmd)
 
 
-async def destroy_scratch_filesystem(scratch_dir) -> None:
+async def destroy_scratch_filesystem(scratch_dir: Path) -> None:
     """
     Destroy scratch folder size quota by using tmpfs filesystem.
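`create_scratch_filesystem` gains `scratch_dir: Path` and `size: int`, which also pins down the units at the call site. A hedged, self-contained sketch of the same shell-out pattern (the mount options are illustrative, and actually running it requires root):

```python
import asyncio
from pathlib import Path
from subprocess import CalledProcessError


async def create_tmpfs(scratch_dir: Path, size: int) -> None:
    # `size` is in bytes here; annotating it as int makes that contract
    # checkable instead of implied.
    cmd = ["mount", "-t", "tmpfs", "-o", f"size={size}", "tmpfs", str(scratch_dir)]
    proc = await asyncio.create_subprocess_exec(*cmd)
    if await proc.wait() != 0:
        raise CalledProcessError(proc.returncode or 1, cmd)
```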
diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py
index 8d45f63b1ad..7c54b566192 100644
--- a/src/ai/backend/agent/kernel.py
+++ b/src/ai/backend/agent/kernel.py
@@ -130,7 +130,7 @@ def _dump_json_bytes(obj: Any) -> bytes:
 class RunEvent(Exception):
     data: Any
 
-    def __init__(self, data=None) -> None:
+    def __init__(self, data: Any = None) -> None:
         super().__init__()
         self.data = data
 
@@ -258,7 +258,7 @@ def __getstate__(self) -> Mapping[str, Any]:
         del props["clean_event"]
         return props
 
-    def __setstate__(self, props) -> None:
+    def __setstate__(self, props: MutableMapping[str, Any]) -> None:
         # Used when a `Kernel` object is loaded from pickle data.
         if "state" not in props:
             props["state"] = KernelLifecycleStatus.RUNNING
@@ -291,7 +291,7 @@ async def close(self) -> None:
     #   - restoration from running containers is done by computer's classmethod
     #     "restore_from_container"
 
-    def release_slots(self, computer_ctxs) -> None:
+    def release_slots(self, computer_ctxs: Mapping[str, Any]) -> None:
         """
         Release the resource slots occupied by the kernel to the allocation maps.
@@ -321,7 +321,7 @@ async def check_status(self) -> dict[str, Any] | None:
         raise NotImplementedError
 
     @abstractmethod
-    async def get_completions(self, text, opts) -> CodeCompletionResp:
+    async def get_completions(self, text: str, opts: Mapping[str, Any]) -> CodeCompletionResp:
         raise NotImplementedError
 
     @abstractmethod
@@ -333,26 +333,26 @@ async def interrupt_kernel(self) -> dict[str, Any]:
         raise NotImplementedError
 
     @abstractmethod
-    async def start_service(self, service, opts) -> dict[str, Any]:
+    async def start_service(self, service: str, opts: Mapping[str, Any]) -> dict[str, Any]:
         raise NotImplementedError
 
     @abstractmethod
-    async def start_model_service(self, model_service) -> dict[str, Any]:
+    async def start_model_service(self, model_service: Mapping[str, Any]) -> dict[str, Any]:
         raise NotImplementedError
 
     @abstractmethod
-    async def shutdown_service(self, service) -> None:
+    async def shutdown_service(self, service: str) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus:
+    async def check_duplicate_commit(self, kernel_id: KernelId, subdir: str) -> CommitStatus:
         raise NotImplementedError
 
     @abstractmethod
     async def commit(
         self,
-        kernel_id,
-        subdir,
+        kernel_id: KernelId,
+        subdir: str,
         *,
         canonical: str | None = None,
         filename: str | None = None,
@@ -758,7 +758,7 @@ def __getstate__(self) -> Mapping[str, Any]:
         del props["event_producer"]
         return props
 
-    def __setstate__(self, props) -> None:
+    def __setstate__(self, props: MutableMapping[str, Any]) -> None:
         global _zctx
         self.__dict__.update(props)
         if _zctx is None:
@@ -845,7 +845,7 @@ async def ping_status(self) -> None:
         except Exception:
             log.exception("AbstractCodeRunner.ping_status(): unexpected error")
 
-    async def feed_batch(self, opts) -> None:
+    async def feed_batch(self, opts: Mapping[str, Any]) -> None:
         sock = await self._get_socket_pair()
         clean_cmd = opts.get("clean", "")
         if clean_cmd is None:
@@ -899,7 +899,9 @@ async def feed_and_get_status(self) -> dict[str, float] | None:
         except asyncio.CancelledError:
             return None
 
-    async def feed_and_get_completion(self, code_text, opts) -> CodeCompletionResult:
+    async def feed_and_get_completion(
+        self, code_text: str, opts: Mapping[str, Any]
+    ) -> CodeCompletionResult:
         sock = await self._get_socket_pair()
         payload = {
             "code": code_text,
@@ -916,7 +918,7 @@ async def feed_and_get_completion(self, code_text, opts) -> CodeCompletionResult
         except asyncio.CancelledError:
             return CodeCompletionResult.failure()
 
-    async def feed_start_model_service(self, model_info) -> dict[str, Any]:
+    async def feed_start_model_service(self, model_info: Mapping[str, Any]) -> dict[str, Any]:
         sock = await self._get_socket_pair()
         await sock.send_multipart([
             b"start-model-service",
@@ -938,7 +940,7 @@ async def feed_start_model_service(self, model_info) -> dict[str, Any]:
         except TimeoutError:
             return {"status": "failed", "error": "timeout"}
 
-    async def feed_start_service(self, service_info) -> dict[str, Any]:
+    async def feed_start_service(self, service_info: Mapping[str, Any]) -> dict[str, Any]:
         sock = await self._get_socket_pair()
         await sock.send_multipart([
             b"start-service",
@@ -1050,7 +1052,7 @@ def aggregate_console(
         else:
             raise AssertionError("Unrecognized API version")
 
-    async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult:
+    async def get_next_result(self, api_ver: int = 2, flush_timeout: float = 2.0) -> NextResult:
         # Context: per API request
         has_continuation = ClientFeatures.CONTINUATION in self.client_features
         records = []
diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py
index 7db8058ccb1..39b963a7b9a 100644
--- a/src/ai/backend/agent/kubernetes/agent.py
+++ b/src/ai/backend/agent/kubernetes/agent.py
@@ -489,7 +489,11 @@ async def mount_vfolders(
             resource_spec.mounts.append(mount)
 
     @override
-    async def apply_accelerator_allocation(self, computer, device_alloc) -> None:
+    async def apply_accelerator_allocation(
+        self,
+        computer: AbstractComputePlugin,
+        device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
+    ) -> None:
         # update_nested_dict(
         #     self.computer_docker_args,
         #     await computer.generate_docker_args(self.docker, device_alloc))
@@ -550,7 +554,7 @@ async def prepare_container(
         self,
         resource_spec: KernelResourceSpec,
         environ: Mapping[str, str],
-        service_ports,
+        service_ports: Any,
         cluster_info: ClusterInfo,
     ) -> KubernetesKernel:
         loop = current_loop()
@@ -687,8 +691,8 @@ async def start_container(
         self,
         kernel_obj: AbstractKernel,
         cmdargs: list[str],
-        resource_opts,
-        preopen_ports,
+        resource_opts: Any,
+        preopen_ports: Any,
         cluster_info: ClusterInfo,
     ) -> Mapping[str, Any]:
         image_labels = self.kernel_config["image"]["labels"]
@@ -908,14 +912,17 @@ async def check_krunner_pv_status(self) -> None:
         if len(pv.items) == 0:
             # PV does not exists; create one
             if self.local_config.container.scratch_type == ScratchType.K8S_NFS:
+                scratch_nfs_address = self.local_config.container.scratch_nfs_address
+                if scratch_nfs_address is None:
+                    raise K8sError("scratch_nfs_address must be set when using K8S_NFS")
                 new_pv: NFSPersistentVolume | HostPathPersistentVolume = NFSPersistentVolume(
-                    self.local_config.container.scratch_nfs_address,
+                    scratch_nfs_address,
                     "backend-ai-static-pv",
                     capacity,
                 )
                 new_pv.label(
                     "backend.ai/backend-ai-scratch-volume",
-                    self.local_config.container.scratch_nfs_address,
+                    scratch_nfs_address,
                 )
                 scratch_nfs_options = self.local_config.container.scratch_nfs_options
                 if scratch_nfs_options is not None:
@@ -949,9 +956,12 @@ async def check_krunner_pv_status(self) -> None:
                 capacity,
             )
             if self.local_config.container.scratch_type == ScratchType.K8S_NFS:
+                scratch_nfs_address = self.local_config.container.scratch_nfs_address
+                if scratch_nfs_address is None:
+                    raise K8sError("scratch_nfs_address must be set when using K8S_NFS")
                 new_pvc.label(
                     "backend.ai/backend-ai-scratch-volume",
-                    self.local_config.container.scratch_nfs_address,
+                    scratch_nfs_address,
                 )
             else:
                 new_pvc.label("backend.ai/backend-ai-scratch-volume", "hostPath")
diff --git a/src/ai/backend/agent/kubernetes/files.py b/src/ai/backend/agent/kubernetes/files.py
index 36848e332c4..630f67602c0 100644
--- a/src/ai/backend/agent/kubernetes/files.py
+++ b/src/ai/backend/agent/kubernetes/files.py
@@ -17,7 +17,7 @@
     log.info("Automatic ~/.output file S3 uploads is disabled.")
 
 
-def relpath(path, base) -> Path:
+def relpath(path: str | Path, base: str | Path) -> Path:
     return Path(path).resolve().relative_to(Path(base).resolve())
 
 
@@ -52,7 +52,7 @@ def scandir(root: Path, allowed_max_size: int) -> dict[Path, float]:
     return file_stats
 
 
-def diff_file_stats(fs1, fs2) -> set[Path]:
+def diff_file_stats(fs1: dict[Path, float], fs2: dict[Path, float]) -> set[Path]:
     k2 = set(fs2.keys())
     k1 = set(fs1.keys())
     new_files = k2 - k1
diff --git a/src/ai/backend/agent/kubernetes/intrinsic.py b/src/ai/backend/agent/kubernetes/intrinsic.py
index f66fc1799b2..8a9200bcd80 100644
--- a/src/ai/backend/agent/kubernetes/intrinsic.py
+++ b/src/ai/backend/agent/kubernetes/intrinsic.py
@@ -187,7 +187,7 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: Docker,
-        device_alloc,
+        device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
     ) -> Mapping[str, Any]:
         # This function might be needed later to apply fine-grained tuning for
         # K8s resource allocation
@@ -346,7 +346,7 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
     async def generate_docker_args(
         self,
         docker: Docker,
-        device_alloc,
+        device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
     ) -> Mapping[str, Any]:
         # This function might be needed later to apply fine-grained tuning for
         # K8s resource allocation
diff --git a/src/ai/backend/agent/kubernetes/kernel.py b/src/ai/backend/agent/kubernetes/kernel.py
index 914b5e4c1d1..d3e26292873 100644
--- a/src/ai/backend/agent/kubernetes/kernel.py
+++ b/src/ai/backend/agent/kubernetes/kernel.py
@@ -226,15 +226,15 @@ async def get_service_apps(self) -> dict[str, Any]:
         return await self.runner.feed_service_apps()
 
     @override
-    async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus:
+    async def check_duplicate_commit(self, kernel_id: Any, subdir: Any) -> CommitStatus:
         log.error("Committing in Kubernetes is not supported yet.")
         raise NotImplementedError
 
     @override
     async def commit(
         self,
-        kernel_id,
-        subdir,
+        kernel_id: Any,
+        subdir: Any,
         *,
         canonical: str | None = None,
         filename: str | None = None,
@@ -369,15 +369,15 @@ class KubernetesCodeRunner(AbstractCodeRunner):
 
     def __init__(
         self,
-        kernel_id,
-        session_id,
-        event_producer,
+        kernel_id: Any,
+        session_id: Any,
+        event_producer: EventProducer,
         *,
-        kernel_host,
-        repl_in_port,
-        repl_out_port,
-        exec_timeout=0,
-        client_features=None,
+        kernel_host: str,
+        repl_in_port: int,
+        repl_out_port: int,
+        exec_timeout: int = 0,
+        client_features: frozenset[str] | None = None,
    ) -> None:
         super().__init__(
             kernel_id,
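The repeated `scratch_nfs_address` hunks in `kubernetes/agent.py` are the standard narrowing idiom: bind the `Optional` once, check it, then use the non-None local everywhere. A minimal sketch with stand-in config and error types:

```python
from dataclasses import dataclass


class K8sConfigError(Exception):
    pass


@dataclass
class ContainerConfig:
    scratch_nfs_address: str | None = None


def resolve_nfs_address(config: ContainerConfig) -> str:
    # Re-reading config.scratch_nfs_address after the check would force the
    # type checker to re-prove non-None-ness; the local keeps the narrowing.
    address = config.scratch_nfs_address
    if address is None:
        raise K8sConfigError("scratch_nfs_address must be set when using K8S_NFS")
    return address
```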
= {"backend.ai/kernel-id": kernel_id} @@ -76,7 +76,7 @@ def to_dict(self) -> dict: class Service(AbstractAPIObject): def __init__( - self, kernel_id: str, name: str, container_port: list, service_type="NodePort" + self, kernel_id: str, name: str, container_port: list[Any], service_type: str = "NodePort" ) -> None: self.name = name self.deployment_name = f"kernel-{kernel_id}" @@ -108,7 +108,7 @@ def to_dict(self) -> dict: class NFSPersistentVolume(AbstractAPIObject): - def __init__(self, server, name, capacity, path="/") -> None: + def __init__(self, server: str, name: str, capacity: str, path: str = "/") -> None: self.server = server self.path = path self.name = name @@ -116,7 +116,7 @@ def __init__(self, server, name, capacity, path="/") -> None: self.labels: dict[str, str] = {} self.options: list[str] = [] - def label(self, k, v) -> None: + def label(self, k: str, v: str) -> None: self.labels[k] = v def to_dict(self) -> dict: @@ -142,14 +142,14 @@ def to_dict(self) -> dict: class HostPathPersistentVolume(AbstractAPIObject): - def __init__(self, path, name, capacity) -> None: + def __init__(self, path: str, name: str, capacity: str) -> None: self.path = path self.name = name self.capacity = capacity self.labels: dict[str, str] = {} self.options: list[str] = [] - def label(self, k, v) -> None: + def label(self, k: str, v: str) -> None: self.labels[k] = v def to_dict(self) -> dict: @@ -174,13 +174,13 @@ def to_dict(self) -> dict: class PersistentVolumeClaim(AbstractAPIObject): - def __init__(self, name, pv_name, capacity) -> None: + def __init__(self, name: str, pv_name: str, capacity: str) -> None: self.name = name self.pv_name = pv_name self.capacity = capacity self.labels: dict[str, str] = {} - def label(self, k, v) -> None: + def label(self, k: str, v: str) -> None: self.labels[k] = v def to_dict(self) -> dict: diff --git a/src/ai/backend/agent/kubernetes/resources.py b/src/ai/backend/agent/kubernetes/resources.py index dd9aca2481f..c4a3fc62991 100644 --- a/src/ai/backend/agent/kubernetes/resources.py +++ b/src/ai/backend/agent/kubernetes/resources.py @@ -86,7 +86,7 @@ async def scan_available_resources( return slots -async def get_resource_spec_from_container(container_info) -> Optional[KernelResourceSpec]: +async def get_resource_spec_from_container(container_info: Any) -> Optional[KernelResourceSpec]: for mount in container_info["HostConfig"]["Mounts"]: if mount["Target"] == "/home/config": async with aiofiles.open(Path(mount["Source"]) / "resource.txt") as f: # type: ignore diff --git a/src/ai/backend/agent/resources.py b/src/ai/backend/agent/resources.py index 18c3bda9f3e..599f62dca42 100644 --- a/src/ai/backend/agent/resources.py +++ b/src/ai/backend/agent/resources.py @@ -827,7 +827,7 @@ def discover_plugins( ) -> Iterator[tuple[str, type[AbstractComputePlugin]]]: scanned_plugins = [*super().discover_plugins(plugin_group, allowlist, blocklist)] - def accel_lt_intrinsic(item) -> int: + def accel_lt_intrinsic(item: tuple[str, type[AbstractComputePlugin]]) -> int: # push back "intrinsic" plugins (if exists) if item[0] in ("cpu", "mem"): return 0 @@ -852,22 +852,25 @@ def __str__(self) -> str: return f"{self.source}:{self.target}:{self.permission.value}" @classmethod - def from_str(cls, s) -> Self: - source, target, perm = s.split(":") - source = Path(source) + def from_str(cls, s: str) -> Self: + source_str, target_str, perm_str = s.split(":") + source_path = Path(source_str) type = MountTypes.BIND - if not source.is_absolute(): - if len(source.parts) == 1: - source = str(source) + 
diff --git a/src/ai/backend/agent/resources.py b/src/ai/backend/agent/resources.py
index 18c3bda9f3e..599f62dca42 100644
--- a/src/ai/backend/agent/resources.py
+++ b/src/ai/backend/agent/resources.py
@@ -827,7 +827,7 @@ def discover_plugins(
     ) -> Iterator[tuple[str, type[AbstractComputePlugin]]]:
         scanned_plugins = [*super().discover_plugins(plugin_group, allowlist, blocklist)]
 
-        def accel_lt_intrinsic(item) -> int:
+        def accel_lt_intrinsic(item: tuple[str, type[AbstractComputePlugin]]) -> int:
             # push back "intrinsic" plugins (if exists)
             if item[0] in ("cpu", "mem"):
                 return 0
@@ -852,22 +852,25 @@ def __str__(self) -> str:
         return f"{self.source}:{self.target}:{self.permission.value}"
 
     @classmethod
-    def from_str(cls, s) -> Self:
-        source, target, perm = s.split(":")
-        source = Path(source)
+    def from_str(cls, s: str) -> Self:
+        source_str, target_str, perm_str = s.split(":")
+        source_path = Path(source_str)
         type = MountTypes.BIND
-        if not source.is_absolute():
-            if len(source.parts) == 1:
-                source = str(source)
+        source: Optional[Path]
+        if not source_path.is_absolute():
+            if len(source_path.parts) == 1:
+                source = Path(source_str)
                 type = MountTypes.VOLUME
             else:
                 raise ValueError(
-                    "Mount source must be an absolute path if it is not a volume name.", source
+                    "Mount source must be an absolute path if it is not a volume name.", source_path
                 )
-        target = Path(target)
+        else:
+            source = source_path
+        target = Path(target_str)
         if not target.is_absolute():
             raise ValueError("Mount target must be an absolute path.", target)
-        perm = MountPermission(perm)
+        perm = MountPermission(perm_str)
         return cls(type, source, target, perm, None)
diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py
index 9043147a071..cd25c438b78 100644
--- a/src/ai/backend/agent/server.py
+++ b/src/ai/backend/agent/server.py
@@ -437,7 +437,7 @@ async def _debug_server_task() -> None:
     async def status_snapshot_request_handler(
         self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
     ) -> None:
-        def _ensure_serializable(o) -> Any:
+        def _ensure_serializable(o: Any) -> Any:
             match o:
                 case dict() | defaultdict() | OrderedDict():
                     return {_ensure_serializable(k): _ensure_serializable(v) for k, v in o.items()}
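The `Mount.from_str` rewrite above shows why annotated code often needs one variable per type: reusing `source` for the raw string, the `Path`, and back to `str` defeats inference. A condensed sketch of the same staged parsing (the enum and return shape are simplified stand-ins):

```python
from enum import Enum
from pathlib import Path
from typing import Optional


class MountPermission(Enum):
    READ_ONLY = "ro"
    READ_WRITE = "rw"


def parse_mount(s: str) -> tuple[Optional[Path], Path, MountPermission]:
    # Each parsing stage gets its own name: *_str for raw pieces, then typed
    # values, so every variable keeps a single stable type.
    source_str, target_str, perm_str = s.split(":")
    source: Optional[Path] = Path(source_str) if source_str else None
    target = Path(target_str)
    if not target.is_absolute():
        raise ValueError("Mount target must be an absolute path.", target)
    return source, target, MountPermission(perm_str)


print(parse_mount("/data:/home/work/data:rw"))
```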
diff --git a/src/ai/backend/agent/vendor/linux.py b/src/ai/backend/agent/vendor/linux.py
index 1b4a5710f9c..4fc6fc80354 100644
--- a/src/ai/backend/agent/vendor/linux.py
+++ b/src/ai/backend/agent/vendor/linux.py
@@ -47,7 +47,7 @@ def parse_cpuset(value: str) -> Iterator[int]:
 
 class libnuma:
     @staticmethod
-    def node_of_cpu(core) -> int:
+    def node_of_cpu(core: int) -> int:
         if _numa_supported:
             return int(_libnuma.numa_node_of_cpu(core))  # type: ignore
         return 0
@@ -163,7 +163,7 @@ async def read_os_cpus() -> tuple[set[int], str] | None:
             log.debug("read the available cpuset from {}", cpuset_source)
 
     @staticmethod
-    async def get_core_topology(limit_cpus=None) -> tuple[list[int], ...]:
+    async def get_core_topology(limit_cpus: set[int] | None = None) -> tuple[list[int], ...]:
         topo: tuple[list[int], ...] = tuple([] for _ in range(libnuma.num_nodes()))
         for c in await libnuma.get_available_cores():
             if limit_cpus is not None and c not in limit_cpus:
diff --git a/src/ai/backend/agent/watcher/__init__.py b/src/ai/backend/agent/watcher/__init__.py
index c31d3f21d88..9ff10f5bb8e 100644
--- a/src/ai/backend/agent/watcher/__init__.py
+++ b/src/ai/backend/agent/watcher/__init__.py
@@ -5,7 +5,7 @@
 import ssl
 import subprocess
 import sys
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
 from http import HTTPStatus
 from pathlib import Path
 from pprint import pformat, pprint
@@ -34,7 +34,9 @@
 
 
 @web.middleware
-async def auth_middleware(request, handler) -> web.StreamResponse:
+async def auth_middleware(
+    request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
+) -> web.StreamResponse:
     token = request.headers.get("X-BackendAI-Watcher-Token", None)
     if token == request.app["token"]:
         try:
@@ -236,7 +238,7 @@ async def handle_umount(request: web.Request) -> web.Response:
     return web.Response(text=out)
 
 
-async def init_app(app) -> None:
+async def init_app(app: web.Application) -> None:
     r = app.router.add_route
     r("GET", "/", handle_status)
     if app["config"]["watcher"]["soft-reset-available"]:
@@ -252,11 +254,11 @@ async def init_app(app) -> None:
     r("DELETE", "/mounts", handle_umount)
 
 
-async def shutdown_app(app) -> None:
+async def shutdown_app(app: web.Application) -> None:
     pass
 
 
-async def prepare_hook(request, response) -> None:
+async def prepare_hook(request: web.Request, response: web.StreamResponse) -> None:
     response.headers["Server"] = "BackendAI-AgentWatcher"
diff --git a/src/ai/backend/appproxy/common/openapi.py b/src/ai/backend/appproxy/common/openapi.py
index 74e96346a4b..3fa3ba84b76 100644
--- a/src/ai/backend/appproxy/common/openapi.py
+++ b/src/ai/backend/appproxy/common/openapi.py
@@ -28,7 +28,7 @@ def get_path_parameters(resource: AbstractResource) -> list[dict]:
 
 
 def generate_openapi(
-    component: str, subapps: list[web.Application], verbose=False
+    component: str, subapps: list[web.Application], verbose: bool = False
 ) -> dict[str, Any]:
     openapi: dict[str, Any] = {
         "openapi": "3.1.0",
diff --git a/src/ai/backend/appproxy/common/utils.py b/src/ai/backend/appproxy/common/utils.py
index 98418812b47..7920651bce6 100644
--- a/src/ai/backend/appproxy/common/utils.py
+++ b/src/ai/backend/appproxy/common/utils.py
@@ -39,15 +39,15 @@
 _danger_words = ["password", "passwd", "secret"]
 
 
-def set_handler_attr(func, key, value) -> None:
+def set_handler_attr(func: Callable[..., Any], key: str, value: Any) -> None:
     attrs = getattr(func, "_backend_attrs", None)
     if attrs is None:
         attrs = {}
     attrs[key] = value
-    func._backend_attrs = attrs
+    func._backend_attrs = attrs  # type: ignore[attr-defined]
 
 
-def get_handler_attr(request, key, default=None) -> Any:
+def get_handler_attr(request: web.Request, key: str, default: Any = None) -> Any:
     # When used in the aiohttp server-side codes, we should use
     # request.match_info.hanlder instead of handler passed to the middleware
     # functions because aiohttp wraps this original handler with functools.partial
@@ -175,7 +175,7 @@ def ensure_stream_response_type[TAnyResponse: web.StreamResponse](
 
 def pydantic_api_response_handler(
     handler: THandlerFuncWithoutParam,
-    is_deprecated=False,
+    is_deprecated: bool = False,
 ) -> Handler:
     """
     Only for API handlers which does not require request body.
@@ -200,7 +200,7 @@ def pydantic_api_handler[TParamModel: BaseModel, TQueryModel: BaseModel](
     checker: type[TParamModel],
     loads: Callable[[str], Any] | None = None,
     query_param_checker: type[TQueryModel] | None = None,
-    is_deprecated=False,
+    is_deprecated: bool = False,
 ) -> Callable[[THandlerFuncWithParam], Handler]:
     def wrap(
         handler: THandlerFuncWithParam,
@@ -280,7 +280,7 @@ def config_key_to_kebab_case(o: Any) -> Any:
     return o
 
 
-def mime_match(base_array: str, compare: str, strict=False) -> bool:
+def mime_match(base_array: str, compare: str, strict: bool = False) -> bool:
     """
     Checks if `base_array` MIME string contains `compare` MIME type.
 
@@ -307,7 +307,7 @@ class BackendAIAccessLogger(AccessLogger):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
-    def log(self, request, response, time) -> None:
+    def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None:
         if request.get("do_not_print_access_log"):
             return
diff --git a/src/ai/backend/appproxy/coordinator/cli/__main__.py b/src/ai/backend/appproxy/coordinator/cli/__main__.py
index 745c3cc9467..d11bb885094 100644
--- a/src/ai/backend/appproxy/coordinator/cli/__main__.py
+++ b/src/ai/backend/appproxy/coordinator/cli/__main__.py
@@ -168,7 +168,9 @@ def generate_openapi_spec(output: Path) -> None:
 )
 @click.argument("psql_args", nargs=-1, type=click.UNPROCESSED)
 @click.pass_obj
-def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args) -> None:
+def dbshell(
+    cli_ctx: CLIContext, container_name: str | None, psql_help: bool, psql_args: list[str]
+) -> None:
     """
     Run the database shell.
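Click injects option and argument values positionally, so the `dbshell` annotations must mirror each declaration. One nuance worth noting: click delivers `nargs=-1` arguments as a tuple at runtime, so `tuple[str, ...]` is the most precise spelling (the patch's `list[str]` is close but not exact). A stand-alone sketch:

```python
import click


@click.command()
@click.option("--container-name", type=str, default=None)
@click.option("--psql-help", is_flag=True, default=False)
@click.argument("psql_args", nargs=-1, type=click.UNPROCESSED)
def dbshell(container_name: str | None, psql_help: bool, psql_args: tuple[str, ...]) -> None:
    # Each annotation matches what click actually passes: a nullable string,
    # a boolean flag, and a variadic argument tuple.
    click.echo(f"{container_name=} {psql_help=} {psql_args=}")


if __name__ == "__main__":
    dbshell()
```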
diff --git a/src/ai/backend/appproxy/coordinator/cli/dbschema.py b/src/ai/backend/appproxy/coordinator/cli/dbschema.py
index 7745fcac4a3..008324e95c7 100644
--- a/src/ai/backend/appproxy/coordinator/cli/dbschema.py
+++ b/src/ai/backend/appproxy/coordinator/cli/dbschema.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import click
 
@@ -16,7 +16,7 @@
 
 
 @click.group()
-def cli(args) -> None:
+def cli(args: Any) -> None:
     pass
 
 
@@ -29,7 +29,7 @@ def cli(args) -> None:
     help="The path to Alembic config file. [default: alembic-appproxy.ini]",
 )
 @click.pass_obj
-def show(cli_ctx: CLIContext, alembic_config) -> None:
+def show(cli_ctx: CLIContext, alembic_config: str) -> None:
     """Show the current schema information."""
     from alembic.config import Config
     from alembic.runtime.migration import MigrationContext
diff --git a/src/ai/backend/appproxy/coordinator/models/base.py b/src/ai/backend/appproxy/coordinator/models/base.py
index 65186c6f9a7..21579e9cc18 100644
--- a/src/ai/backend/appproxy/coordinator/models/base.py
+++ b/src/ai/backend/appproxy/coordinator/models/base.py
@@ -49,7 +49,7 @@
 
 
 class BaseMixin:
-    def dump_model(self, serializable=True) -> dict[str, Any]:
+    def dump_model(self, serializable: bool = True) -> dict[str, Any]:
         o = dict(self.__dict__)
         del o["_sa_instance_state"]
         if serializable:
@@ -68,7 +68,7 @@ def dump_model(self, serializable=True) -> dict[str, Any]:
 
 
 # helper functions
-def zero_if_none(val) -> int:
+def zero_if_none(val: int | None) -> int:
     return 0 if val is None else val
 
 
@@ -84,7 +84,7 @@ class EnumType(TypeDecorator, SchemaType):  # type: ignore
     impl = ENUM
     cache_ok = True
 
-    def __init__(self, enum_cls, **opts) -> None:
+    def __init__(self, enum_cls: type[enum.Enum], **opts: Any) -> None:
         if not issubclass(enum_cls, enum.Enum):
             raise InvalidEnumTypeError(f"Expected an Enum subclass, got {enum_cls}")
         if "name" not in opts:
@@ -94,10 +94,10 @@ def __init__(self, enum_cls, **opts) -> None:
         super().__init__(*enums, **opts)
         self._enum_cls = enum_cls
 
-    def process_bind_param(self, value, dialect) -> Optional[str]:
+    def process_bind_param(self, value: enum.Enum | None, dialect: sa.Dialect) -> Optional[str]:
         return value.name if value else None
 
-    def process_result_value(self, value: Any, dialect) -> Optional[enum.Enum]:
+    def process_result_value(self, value: Any, dialect: sa.Dialect) -> Optional[enum.Enum]:
         return self._enum_cls[value] if value else None
 
     def copy(self, **kw: Any) -> EnumType:
@@ -165,12 +165,12 @@ def __init__(self, schema: type[BaseModel]) -> None:
         super().__init__()
         self._schema = schema
 
-    def load_dialect_impl(self, dialect) -> TypeEngine[Any]:
+    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
         if dialect.name == "sqlite":
-            return dialect.type_descriptor(sa.JSON)
+            return dialect.type_descriptor(sa.JSON())  # type: ignore[arg-type]
         return super().load_dialect_impl(dialect)
 
-    def process_bind_param(self, value, dialect) -> BaseModel:
+    def process_bind_param(self, value: Any, dialect: sa.Dialect) -> BaseModel:
         if value is None:
             return self._schema()
         try:
@@ -182,7 +182,7 @@ def process_bind_param(self, value, dialect) -> BaseModel:
             ) from e
         return value
 
-    def process_result_value(self, value, dialect) -> BaseModel:
+    def process_result_value(self, value: Any, dialect: sa.Dialect) -> BaseModel:
         if value is None:
             return self._schema()
         return self._schema(**value)
@@ -203,12 +203,12 @@ def __init__(self, schema: type[BaseModel]) -> None:
         super().__init__()
         self._schema = schema
 
-    def process_bind_param(self, value: BaseModel | None, dialect) -> Optional[str]:
+    def process_bind_param(self, value: BaseModel | None, dialect: sa.Dialect) -> Optional[str]:
         if value:
             return value.model_dump_json()
         return None
 
-    def process_result_value(self, value: str | None, dialect) -> Optional[BaseModel]:
+    def process_result_value(self, value: str | None, dialect: sa.Dialect) -> Optional[BaseModel]:
         if value:
             return self._schema(**json.loads(value))
         return None
@@ -233,12 +233,16 @@ def __init__(self, schema: type[TBaseModel]) -> None:
         super().__init__()
         self._schema = schema
 
-    def process_bind_param(self, value: list[TBaseModel] | None, dialect) -> Optional[str]:
+    def process_bind_param(
+        self, value: list[TBaseModel] | None, dialect: sa.Dialect
+    ) -> Optional[str]:
         if value is not None:
             return TypeAdapter(list[TBaseModel]).dump_json(value).decode("utf-8")
         return None
 
-    def process_result_value(self, value: str | None, dialect) -> Optional[list[TBaseModel]]:
+    def process_result_value(
+        self, value: str | None, dialect: sa.Dialect
+    ) -> Optional[list[TBaseModel]]:
         if value is not None:
             return [self._schema(**i) for i in json.loads(value)]
         return None
@@ -255,12 +259,14 @@ class URLColumn(TypeDecorator):
     impl = sa.types.UnicodeText
     cache_ok = True
 
-    def process_bind_param(self, value, dialect) -> Optional[str]:
+    def process_bind_param(
+        self, value: yarl.URL | str | None, dialect: sa.Dialect
+    ) -> Optional[str]:
         if isinstance(value, yarl.URL):
             return str(value)
         return value
 
-    def process_result_value(self, value, dialect) -> Optional[yarl.URL]:
+    def process_result_value(self, value: str | None, dialect: sa.Dialect) -> Optional[yarl.URL]:
         if value is None:
             return None
         if value is not None:
@@ -276,7 +282,7 @@ class IPColumn(TypeDecorator):
     impl = CIDR
     cache_ok = True
 
-    def process_bind_param(self, value, dialect) -> Optional[str]:
+    def process_bind_param(self, value: str | None, dialect: sa.Dialect) -> Optional[str]:
         if value is None:
             return value
         try:
@@ -285,7 +291,9 @@ def process_bind_param(self, value, dialect) -> Optional[str]:
             raise InvalidAPIParameters(f"{value} is invalid IP address value") from e
         return cidr
 
-    def process_result_value(self, value, dialect) -> Optional[ReadableCIDR]:
+    def process_result_value(
+        self, value: str | None, dialect: sa.Dialect
+    ) -> Optional[ReadableCIDR]:
         if value is None:
             return None
         return ReadableCIDR(value)
@@ -304,12 +312,12 @@ class GUID[UUID_SubType: uuid.UUID](TypeDecorator):
     uuid_subtype_func: ClassVar[Callable[[Any], Any]] = lambda v: v
     cache_ok = True
 
-    def load_dialect_impl(self, dialect) -> TypeEngine[Any]:
+    def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]:
         if dialect.name == "postgresql":
             return dialect.type_descriptor(UUID())
         return dialect.type_descriptor(CHAR(16))
 
-    def process_bind_param(self, value: Any, dialect) -> Optional[str | bytes]:
+    def process_bind_param(self, value: Any, dialect: sa.Dialect) -> Optional[str | bytes]:
         # NOTE: EndpointId, SessionId, KernelId are *not* actual types defined as classes,
         # but a "virtual" type that is an identity function at runtime.
         # The type checker treats them as distinct derivatives of uuid.UUID.
@@ -324,7 +332,7 @@ def process_bind_param(self, value: Any, dialect) -> Optional[str | bytes]:
                 return value.bytes
         return uuid.UUID(value).bytes
 
-    def process_result_value(self, value: Any, dialect) -> Optional[UUID_SubType]:
+    def process_result_value(self, value: Any, dialect: sa.Dialect) -> Optional[UUID_SubType]:
         if value is None:
             return value
         cls = type(self)
@@ -337,9 +345,9 @@ def process_result_value(self, value: Any, dialect) -> Optional[UUID_SubType]:
         return cast(UUID_SubType, cls.uuid_subtype_func(uuid.UUID(value)))
 
 
-def IDColumn(name="id") -> sa.Column:
+def IDColumn(name: str = "id") -> sa.Column:
     return sa.Column(name, GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
 
 
-def ForeignKeyIDColumn(name, fk_field, nullable=True) -> sa.Column:
+def ForeignKeyIDColumn(name: str, fk_field: str, nullable: bool = True) -> sa.Column:
     return sa.Column(name, GUID, sa.ForeignKey(fk_field), nullable=nullable)
diff --git a/src/ai/backend/appproxy/coordinator/models/circuit.py b/src/ai/backend/appproxy/coordinator/models/circuit.py
index d98d6e1f049..61ad7b3cc61 100644
--- a/src/ai/backend/appproxy/coordinator/models/circuit.py
+++ b/src/ai/backend/appproxy/coordinator/models/circuit.py
@@ -115,8 +115,8 @@ async def get(
         cls,
         session: AsyncSession,
         circuit_id: UUID,
-        load_worker=True,
-        load_endpoint=True,
+        load_worker: bool = True,
+        load_endpoint: bool = True,
     ) -> "Circuit":
         query = sa.select(Circuit).where(Circuit.id == circuit_id)
         if load_worker:
@@ -133,8 +133,8 @@ async def get_by_endpoint(
         cls,
         session: AsyncSession,
         endpoint_id: UUID,
-        load_worker=True,
-        load_endpoint=True,
+        load_worker: bool = True,
+        load_endpoint: bool = True,
     ) -> "Circuit":
         query = sa.select(Circuit).where(Circuit.endpoint_id == endpoint_id)
         if load_worker:
@@ -150,8 +150,8 @@ async def get_by_endpoint(
     async def list_circuits(
         cls,
         session: AsyncSession,
-        load_worker=True,
-        load_endpoint=True,
+        load_worker: bool = True,
+        load_endpoint: bool = True,
     ) -> Sequence["Circuit"]:
         query = sa.select(Circuit)
         if load_worker:
@@ -190,8 +190,8 @@ async def find_by_endpoint(
         cls,
         session: AsyncSession,
         endpoint_id: UUID,
-        load_worker=True,
-        load_endpoint=True,
+        load_worker: bool = True,
+        load_endpoint: bool = True,
     ) -> "Circuit":
         query = sa.select(Circuit).where(Circuit.endpoint_id == endpoint_id)
         if load_worker:
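The `load_worker`/`load_endpoint` hunks show that a default alone (`=True`) does not count as a type under ANN001; the annotation still has to be spelled out. A hedged sketch of the same query-helper shape (the table and column names are assumptions, not the project's models):

```python
import uuid
from collections.abc import Sequence

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession


async def list_rows(
    session: AsyncSession,
    table: sa.Table,
    *,
    owner_id: uuid.UUID | None = None,
    limit: int = 20,
) -> Sequence[sa.Row]:
    # `limit: int = 20` carries both the default and the declared type;
    # ANN001 flags the parameter if the annotation is omitted.
    query = sa.select(table).limit(limit)
    if owner_id is not None:
        query = query.where(table.c.owner == owner_id)
    result = await session.execute(query)
    return result.all()
```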
wildcard_traffic_port: int | None = None, traefik_last_used_marker_path: str | None = None, - filtered_apps_only=False, - status=WorkerStatus.LOST, + filtered_apps_only: bool = False, + status: WorkerStatus = WorkerStatus.LOST, ) -> "Worker": w = cls() w.id = id @@ -408,7 +408,7 @@ async def add_circuit( *, envs: dict[str, Any] | None = None, args: str | None = None, - open_to_public=False, + open_to_public: bool = False, allowed_client_ips: str | None = None, preferred_port: int | None = None, preferred_subdomain: str | None = None, diff --git a/src/ai/backend/appproxy/coordinator/server.py b/src/ai/backend/appproxy/coordinator/server.py index 5be589c52e3..e7f98abdf08 100644 --- a/src/ai/backend/appproxy/coordinator/server.py +++ b/src/ai/backend/appproxy/coordinator/server.py @@ -291,8 +291,8 @@ async def _make_message_queue( node_id: str, redis_config: RedisConfig, *, - anycast_stream_key=APPPROXY_ANYCAST_STREAM_KEY, - broadcast_channel=APPPROXY_BROADCAST_CHANNEL, + anycast_stream_key: str = APPPROXY_ANYCAST_STREAM_KEY, + broadcast_channel: str = APPPROXY_BROADCAST_CHANNEL, use_experimental_redis_event_dispatcher: bool = False, ) -> AbstractMessageQueue: redis_profile_target: RedisProfileTarget = RedisProfileTarget.from_dict(redis_config.to_dict()) @@ -1154,7 +1154,7 @@ async def server_main_logwrapper( help="Set the logging verbosity level", ) @click.pass_context -def main(ctx: click.Context, config_path: Path, debug: bool, log_level: LogLevel) -> None: +def main(ctx: click.Context, config_path: Path | None, debug: bool, log_level: LogLevel) -> None: """ Start the proxy-coordinator service as a foreground process. """ diff --git a/src/ai/backend/appproxy/worker/proxy/backend/http.py b/src/ai/backend/appproxy/worker/proxy/backend/http.py index e6b4471c59c..25162640f6d 100644 --- a/src/ai/backend/appproxy/worker/proxy/backend/http.py +++ b/src/ai/backend/appproxy/worker/proxy/backend/http.py @@ -237,7 +237,7 @@ async def proxy_ws(self, request: web.Request) -> web.WebSocketResponse: async def _proxy_task( left: web.WebSocketResponse | aiohttp.ClientWebSocketResponse, right: web.WebSocketResponse | aiohttp.ClientWebSocketResponse, - tag="(unknown)", + tag: str = "(unknown)", ) -> None: nonlocal total_bytes diff --git a/src/ai/backend/appproxy/worker/proxy/backend/tcp.py b/src/ai/backend/appproxy/worker/proxy/backend/tcp.py index 570171e424b..0d259eb62c7 100644 --- a/src/ai/backend/appproxy/worker/proxy/backend/tcp.py +++ b/src/ai/backend/appproxy/worker/proxy/backend/tcp.py @@ -47,7 +47,7 @@ async def bind( async def _pipe( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - tag="(unknown)", + tag: str = "(unknown)", ) -> None: nonlocal total_bytes diff --git a/src/ai/backend/appproxy/worker/proxy/frontend/h2/base.py b/src/ai/backend/appproxy/worker/proxy/frontend/h2/base.py index f3d6ace086d..31d5fde4b09 100644 --- a/src/ai/backend/appproxy/worker/proxy/frontend/h2/base.py +++ b/src/ai/backend/appproxy/worker/proxy/frontend/h2/base.py @@ -26,7 +26,9 @@ async def list_inactive_circuits(self, threshold: int) -> list[Circuit]: # We can't measure activeness of HTTP/2 circuits return [] - async def _log_monitor_task(self, stream: asyncio.StreamReader, log_header_postfix="") -> None: + async def _log_monitor_task( + self, stream: asyncio.StreamReader, log_header_postfix: str = "" + ) -> None: while True: line = await stream.readline() if len(line) == 0: diff --git a/src/ai/backend/cli/extensions.py b/src/ai/backend/cli/extensions.py index fb25d235047..d64f4a6d840 100644 --- 
a/src/ai/backend/cli/extensions.py +++ b/src/ai/backend/cli/extensions.py @@ -1,6 +1,7 @@ import os import signal import sys +from collections.abc import Callable from typing import Any, Optional import click @@ -85,7 +86,7 @@ def command(self, *args, **kwargs) -> Any: if not aliases: return decorator - def _decorator(f) -> click.Command: + def _decorator(f: Callable) -> click.Command: cmd = decorator(f) if aliases: self._commands[cmd.name] = aliases @@ -104,7 +105,7 @@ def group(self, *args, **kwargs) -> Any: if not aliases: return decorator - def _decorator(f) -> click.Group: + def _decorator(f: Callable) -> click.Group: cmd = decorator(f) if aliases: self._commands[cmd.name] = aliases @@ -114,7 +115,7 @@ def _decorator(f) -> click.Group: return _decorator - def get_command(self, ctx, cmd_name) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: if cmd_name in self._aliases: cmd_name = self._aliases[cmd_name] command = super().get_command(ctx, cmd_name) @@ -122,7 +123,7 @@ def get_command(self, ctx, cmd_name) -> Optional[click.Command]: return command return None - def format_commands(self, ctx, formatter) -> None: + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: commands = [] for subcommand in self.list_commands(ctx): cmd = self.get_command(ctx, subcommand) diff --git a/src/ai/backend/cli/interaction.py b/src/ai/backend/cli/interaction.py index 622643bb50f..2681dfb7fd2 100644 --- a/src/ai/backend/cli/interaction.py +++ b/src/ai/backend/cli/interaction.py @@ -6,7 +6,7 @@ Numeric = int | float -def ask_host(prompt: str, default: str = "127.0.0.1", allow_hostname=False) -> str: +def ask_host(prompt: str, default: str = "127.0.0.1", allow_hostname: bool = False) -> str: while True: user_reply = input(f"{prompt} (default: {default}): ") if user_reply == "": @@ -92,7 +92,7 @@ def ask_string_in_array(prompt: str, choices: list, default: str) -> str | None: return user_reply -def ask_path(prompt: str, is_file=True, is_directory=True) -> Path: +def ask_path(prompt: str, is_file: bool = True, is_directory: bool = True) -> Path: if not (is_file or is_directory): print("One of args(is_file/is_directory) has True value.") while True: diff --git a/src/ai/backend/cli/params.py b/src/ai/backend/cli/params.py index cf3f9aad447..020fe6a8c9c 100644 --- a/src/ai/backend/cli/params.py +++ b/src/ai/backend/cli/params.py @@ -226,7 +226,9 @@ def __init__(self, type_: Optional[type[TScalar]] = None) -> None: super().__init__() self.type_ = type_ if type_ is not None else str - def convert(self, arg, param, ctx) -> int | list[Any]: + def convert( + self, arg: str | int, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> int | list[Any]: try: match arg: case int(): diff --git a/src/ai/backend/client/cli/admin/group.py b/src/ai/backend/client/cli/admin/group.py index 2e29e7adc92..b2aa8b85ef7 100644 --- a/src/ai/backend/client/cli/admin/group.py +++ b/src/ai/backend/client/cli/admin/group.py @@ -69,7 +69,7 @@ def info(ctx: CLIContext, id_or_name: str) -> None: @click.option( "-d", "--domain-name", type=str, default=None, help="Domain name to list groups belongs to it." ) -def list(ctx: CLIContext, domain_name) -> None: +def list(ctx: CLIContext, domain_name: str | None) -> None: """ List groups in the given domain. 
(admin privilege required) diff --git a/src/ai/backend/client/cli/admin/keypair.py b/src/ai/backend/client/cli/admin/keypair.py index 2adfa617f87..abeaf4c58f5 100644 --- a/src/ai/backend/client/cli/admin/keypair.py +++ b/src/ai/backend/client/cli/admin/keypair.py @@ -68,7 +68,15 @@ def info(ctx: CLIContext) -> None: @click.option("--order", default=None, help="Set the query ordering expression.") @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list(ctx: CLIContext, user_id, is_active, filter_, order, offset, limit) -> None: +def list( + ctx: CLIContext, + user_id: str | None, + is_active: bool | None, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, +) -> None: """ List keypairs. To show all keypairs or other user's, your access key must have the admin diff --git a/src/ai/backend/client/cli/admin/manager.py b/src/ai/backend/client/cli/admin/manager.py index 3e7eef09144..e3e31505b45 100644 --- a/src/ai/backend/client/cli/admin/manager.py +++ b/src/ai/backend/client/cli/admin/manager.py @@ -76,7 +76,7 @@ def freeze(wait: bool, force_kill: bool) -> None: if force_kill: print_wait("Killing all sessions...") - session.Manager.freeze(force_kill=force_kill) + _ = session.Manager.freeze(force_kill=force_kill) if force_kill: print_done("All sessions are killed.") @@ -94,7 +94,7 @@ def unfreeze() -> None: try: with Session() as session: - session.Manager.unfreeze() + _ = session.Manager.unfreeze() print("Manager is successfully unfrozen.") except Exception as e: print_error(e) @@ -143,7 +143,7 @@ def update(message: str | None) -> None: if message is None: print_info("Cancelled") sys.exit(ExitCode.FAILURE) - session.Manager.update_announcement(enabled=True, message=message) + _ = session.Manager.update_announcement(enabled=True, message=message) print_done("Posted new announcement.") except Exception as e: print_error(e) @@ -160,7 +160,7 @@ def delete() -> None: sys.exit(ExitCode.FAILURE) try: with Session() as session: - session.Manager.update_announcement(enabled=False) + _ = session.Manager.update_announcement(enabled=False) print_done("Deleted announcement.") except Exception as e: print_error(e) @@ -208,7 +208,7 @@ def include_agents(agent_ids: tuple[str, ...]) -> None: try: with Session() as session: - session.Manager.scheduler_op("include-agents", agent_ids) + _ = session.Manager.scheduler_op("include-agents", agent_ids) print_done("The given agents now accepts new sessions.") except Exception as e: print_error(e) @@ -227,7 +227,7 @@ def exclude_agents(agent_ids: tuple[str, ...]) -> None: try: with Session() as session: - session.Manager.scheduler_op("exclude-agents", agent_ids) + _ = session.Manager.scheduler_op("exclude-agents", agent_ids) print_done("The given agents will no longer start new sessions.") except Exception as e: print_error(e) diff --git a/src/ai/backend/client/cli/admin/quota_scope.py b/src/ai/backend/client/cli/admin/quota_scope.py index ebab18f7c91..9bafd577f21 100644 --- a/src/ai/backend/client/cli/admin/quota_scope.py +++ b/src/ai/backend/client/cli/admin/quota_scope.py @@ -194,7 +194,7 @@ def set_( if qsid is None: ctx.output.print_fail("Identifier is not valid") sys.exit(ExitCode.INVALID_ARGUMENT) - session.QuotaScope.set_quota_scope( + _ = session.QuotaScope.set_quota_scope( host=host, qsid=qsid, config=QuotaConfig(limit_bytes=limit_bytes), @@ -251,7 +251,7 @@ def unset( if qsid is None: 
ctx.output.print_fail("Identifier is not valid") sys.exit(ExitCode.INVALID_ARGUMENT) - session.QuotaScope.unset_quota_scope( + _ = session.QuotaScope.unset_quota_scope( host=host, qsid=qsid, ) diff --git a/src/ai/backend/client/cli/admin/resource.py b/src/ai/backend/client/cli/admin/resource.py index b1cfd8c7c9a..4e9defe3b74 100644 --- a/src/ai/backend/client/cli/admin/resource.py +++ b/src/ai/backend/client/cli/admin/resource.py @@ -74,7 +74,7 @@ def recalculate_usage() -> None: with Session() as session: try: - session.Resource.recalculate_usage() + _ = session.Resource.recalculate_usage() print("Resource allocation is re-calculated.") except Exception as e: print_error(e) diff --git a/src/ai/backend/client/cli/admin/scaling_group.py b/src/ai/backend/client/cli/admin/scaling_group.py index 6a68e5b5430..96007950b23 100644 --- a/src/ai/backend/client/cli/admin/scaling_group.py +++ b/src/ai/backend/client/cli/admin/scaling_group.py @@ -293,7 +293,7 @@ def update( @scaling_group.command() @pass_ctx_obj @click.argument("name", type=str, metavar="NAME") -def delete(ctx: CLIContext, name) -> None: +def delete(ctx: CLIContext, name: str) -> None: """ Delete an existing scaling group. @@ -328,7 +328,7 @@ def delete(ctx: CLIContext, name) -> None: @pass_ctx_obj @click.argument("scaling_group", type=str, metavar="SCALING_GROUP") @click.argument("domain", type=str, metavar="DOMAIN") -def associate_scaling_group(ctx: CLIContext, scaling_group, domain) -> None: +def associate_scaling_group(ctx: CLIContext, scaling_group: str, domain: str) -> None: """ Associate a domain with a scaling_group. @@ -365,7 +365,7 @@ def associate_scaling_group(ctx: CLIContext, scaling_group, domain) -> None: @pass_ctx_obj @click.argument("scaling_group", type=str, metavar="SCALING_GROUP") @click.argument("domain", type=str, metavar="DOMAIN") -def dissociate_scaling_group(ctx: CLIContext, scaling_group, domain) -> None: +def dissociate_scaling_group(ctx: CLIContext, scaling_group: str, domain: str) -> None: """ Dissociate a domain from a scaling_group. diff --git a/src/ai/backend/client/cli/admin/storage.py b/src/ai/backend/client/cli/admin/storage.py index 713a8e00a93..8c0af7775d7 100644 --- a/src/ai/backend/client/cli/admin/storage.py +++ b/src/ai/backend/client/cli/admin/storage.py @@ -55,7 +55,9 @@ def info(ctx: CLIContext, vfolder_host: str) -> None: @click.option("--order", default=None, help="Set the query ordering expression.") @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list(ctx: CLIContext, filter_, order, offset, limit) -> None: +def list( + ctx: CLIContext, filter_: str | None, order: str | None, offset: int, limit: int | None +) -> None: """ List storage volumes. 
(super-admin privilege required) diff --git a/src/ai/backend/client/cli/admin/user.py b/src/ai/backend/client/cli/admin/user.py index 7bd7502a813..ec0f24c465b 100644 --- a/src/ai/backend/client/cli/admin/user.py +++ b/src/ai/backend/client/cli/admin/user.py @@ -132,7 +132,15 @@ def info(ctx: CLIContext, email: str) -> None: ) @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None: +def list( + ctx: CLIContext, + status: str | None, + group: str | None, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, +) -> None: """ List users. (admin privilege required) @@ -585,7 +593,7 @@ def validate_input[T_Input](value: T_Input, null_flag: bool, field_name: str) -> @user.command() @pass_ctx_obj @click.argument("email", type=str, metavar="EMAIL") -def delete(ctx: CLIContext, email) -> None: +def delete(ctx: CLIContext, email: str) -> None: """ Inactivate an existing user. @@ -640,7 +648,9 @@ def delete(ctx: CLIContext, email) -> None: "and delegate the ownership to the requested admin." ), ) -def purge(ctx: CLIContext, email, purge_shared_vfolders, delegate_endpoint_ownership) -> None: +def purge( + ctx: CLIContext, email: str, purge_shared_vfolders: bool, delegate_endpoint_ownership: bool +) -> None: """ Delete an existing user. This action cannot be undone. diff --git a/src/ai/backend/client/cli/admin/vfolder.py b/src/ai/backend/client/cli/admin/vfolder.py index 952902da052..a401e5850d0 100644 --- a/src/ai/backend/client/cli/admin/vfolder.py +++ b/src/ai/backend/client/cli/admin/vfolder.py @@ -96,7 +96,14 @@ def _list_cmd(docs: Optional[str] = None) -> Callable[..., None]: ) @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") - def list(ctx: CLIContext, group, filter_, order, offset, limit) -> None: + def list( + ctx: CLIContext, + group: str | None, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, + ) -> None: """ List virtual folders. """ @@ -151,7 +158,7 @@ def list_hosts() -> None: @vfolder.command() @click.argument("vfolder_host") -def perf_metric(vfolder_host) -> None: +def perf_metric(vfolder_host: str) -> None: """ Show the performance statistics of a vfolder host. (superadmin privilege required) @@ -182,7 +189,7 @@ def perf_metric(vfolder_host) -> None: @click.option( "-a", "--agent-id", type=str, default=None, help="Target agent to fetch fstab contents." ) -def get_fstab_contents(agent_id) -> None: +def get_fstab_contents(agent_id: str | None) -> None: """ Get contents of fstab file from a node. (superadmin privilege required) @@ -229,7 +236,7 @@ def list_mounts() -> None: @click.argument("name", type=str) @click.option("-o", "--options", type=str, default=None, help="Mount options.") @click.option("--edit-fstab", is_flag=True, help="Edit fstab file to mount permanently.") -def mount_host(fs_location, name, options, edit_fstab) -> None: +def mount_host(fs_location: str, name: str, options: str | None, edit_fstab: bool) -> None: """ Mount a host in virtual folder root. 
(superadmin privilege required) @@ -259,7 +266,7 @@ def mount_host(fs_location, name, options, edit_fstab) -> None: @vfolder.command() @click.argument("name", type=str) @click.option("--edit-fstab", is_flag=True, help="Edit fstab file to mount permanently.") -def umount_host(name, edit_fstab) -> None: +def umount_host(name: str, edit_fstab: bool) -> None: """ Unmount a host from virtual folder root. (superadmin privilege required) @@ -319,7 +326,7 @@ def list_shared_vfolders() -> None: @vfolder.command @click.argument("vfolder_id", type=str) -def shared_vfolder_info(vfolder_id) -> None: +def shared_vfolder_info(vfolder_id: str) -> None: """Show the vfolder permission information of the given virtual folder. \b @@ -358,7 +365,7 @@ def shared_vfolder_info(vfolder_id) -> None: @click.option( "-p", "--permission", type=str, metavar="PERMISSION", help="Folder's innate permission." ) -def update_shared_vf_permission(vfolder_id, user_id, permission) -> None: +def update_shared_vf_permission(vfolder_id: str, user_id: str, permission: str) -> None: """ Update permission for shared vfolders. @@ -382,7 +389,7 @@ def update_shared_vf_permission(vfolder_id, user_id, permission) -> None: @vfolder.command() @click.argument("vfolder_id", type=str) @click.argument("user_id", type=str) -def remove_shared_vf_permission(vfolder_id, user_id) -> None: +def remove_shared_vf_permission(vfolder_id: str, user_id: str) -> None: """ Remove permission for shared vfolders. @@ -405,7 +412,7 @@ def remove_shared_vf_permission(vfolder_id, user_id) -> None: @vfolder.command() @click.argument("vfolder_id", type=str) @click.argument("user_email", type=str) -def change_vfolder_ownership(vfolder_id, user_email) -> None: +def change_vfolder_ownership(vfolder_id: str, user_email: str) -> None: """ Change the ownership of vfolder @@ -417,7 +424,7 @@ def change_vfolder_ownership(vfolder_id, user_email) -> None: with Session() as session: try: - session.VFolder.change_vfolder_ownership(vfolder_id, user_email) + _ = session.VFolder.change_vfolder_ownership(vfolder_id, user_email) print(f"Now ownership of VFolder:{vfolder_id} goes to User:{user_email}") except Exception as e: print_error(e) diff --git a/src/ai/backend/client/cli/app.py b/src/ai/backend/client/cli/app.py index 21df1ed5eb8..f367ab8c09d 100644 --- a/src/ai/backend/client/cli/app.py +++ b/src/ai/backend/client/cli/app.py @@ -3,7 +3,7 @@ import shlex import sys from collections.abc import MutableMapping, Sequence -from typing import Optional +from typing import Any, Optional import aiohttp import click @@ -293,7 +293,7 @@ async def __aexit__(self, *exc_info) -> None: metavar='"ENVNAME=envvalue"', help="Add additional environment variable when starting service.", ) -def app(session_name, app, bind, arg, env) -> None: +def app(session_name: str, app: str, bind: str, arg: tuple[str, ...], env: tuple[str, ...]) -> None: """ Run a local proxy to a service provided by Backend.AI compute sessions. @@ -331,7 +331,7 @@ def app(session_name, app, bind, arg, env) -> None: @click.argument("session_name", type=str, metavar="SESSION_ID", nargs=1) @click.argument("app_name", type=str, metavar="APP", nargs=-1) @click.option("-l", "--list-names", is_flag=True, help="Just print all available services.") -def apps(session_name, app_name, list_names) -> None: +def apps(session_name: str, app_name: tuple[str, ...], list_names: bool) -> None: """ List available additional arguments and environment variables when starting service. 
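Most of the CLI hunks are the same mechanical translation from click declarations to parameter annotations: a required argument becomes str (or its declared type), an nargs=-1 argument becomes a tuple, an option with default=None becomes X | None, and is_flag=True becomes bool. A hypothetical command, not taken from the patch, showing the correspondence:

    from pathlib import Path

    import click


    @click.command()
    @click.argument("name", type=str)                    # required -> str
    @click.argument("files", type=Path, nargs=-1)        # variadic -> tuple[Path, ...]
    @click.option("--filter", "filter_", default=None)   # untyped, default None -> str | None
    @click.option("--offset", default=0)                 # int default -> int
    @click.option("--limit", type=int, default=None)     # typed but nullable -> int | None
    @click.option("--force", is_flag=True)               # flag -> bool
    def example(
        name: str,
        files: tuple[Path, ...],
        filter_: str | None,
        offset: int,
        limit: int | None,
        force: bool,
    ) -> None:
        click.echo(f"{name=} {files=} {filter_=} {offset=} {limit=} {force=}")

Flags created with is_flag never arrive as None, which is why the flag parameters throughout these hunks are typed bool rather than bool | None.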
@@ -342,7 +342,7 @@ def apps(session_name, app_name, list_names) -> None: """ async def print_arguments() -> None: - apps = [] + apps: list[dict[str, Any]] = [] async with AsyncSession() as api_session: compute_session = api_session.ComputeSession(session_name) apps = await compute_session.stream_app_info() diff --git a/src/ai/backend/client/cli/config.py b/src/ai/backend/client/cli/config.py index f2cd4c7735b..01ae544663a 100644 --- a/src/ai/backend/client/cli/config.py +++ b/src/ai/backend/client/cli/config.py @@ -162,7 +162,7 @@ def logout() -> None: with Session() as session: try: - session.Auth.logout() + _ = session.Auth.logout() print_done("Logout done.") try: (local_state_path / "cookie.dat").unlink() @@ -177,7 +177,7 @@ def logout() -> None: @click.argument("old_password", metavar="OLD_PASSWORD") @click.argument("new_password", metavar="NEW_PASSWORD") @click.argument("new_password2", metavar="NEW_PASSWORD2") -def update_password(old_password, new_password, new_password2) -> None: +def update_password(old_password: str, new_password: str, new_password2: str) -> None: """ Update user's password. """ @@ -188,7 +188,7 @@ def update_password(old_password, new_password, new_password2) -> None: with Session() as session: try: - session.Auth.update_password(old_password, new_password, new_password2) + _ = session.Auth.update_password(old_password, new_password, new_password2) print_done("Password updated.") except Exception as e: print_error(e) @@ -199,7 +199,9 @@ def update_password(old_password, new_password, new_password2) -> None: @click.argument("user_id", metavar="USER_ID") @click.argument("current_password", metavar="CURRENT_PASSWORD") @click.argument("new_password", metavar="NEW_PASSWORD") -def update_password_no_auth(domain, user_id, current_password, new_password) -> None: +def update_password_no_auth( + domain: str, user_id: str, current_password: str, new_password: str +) -> None: """ Update user's password. This is used to update `EXPIRED` password only. """ @@ -207,11 +209,11 @@ def update_password_no_auth(domain, user_id, current_password, new_password) -> try: config = get_config() if config.endpoint_type == "session": - session.Auth.update_password_no_auth_in_session( + _ = session.Auth.update_password_no_auth_in_session( user_id, current_password, new_password ) else: - session.Auth.update_password_no_auth( + _ = session.Auth.update_password_no_auth( domain, user_id, current_password, new_password ) print_done("Password updated.") diff --git a/src/ai/backend/client/cli/dotfile.py b/src/ai/backend/client/cli/dotfile.py index 1defecb67e6..415351cd389 100644 --- a/src/ai/backend/client/cli/dotfile.py +++ b/src/ai/backend/client/cli/dotfile.py @@ -47,7 +47,14 @@ def dotfile() -> None: "(If group name is provided, domain name must be specified with option -d)" ), ) -def create(path, permission, dotfile_path, owner_access_key, domain, group) -> None: +def create( + path: str, + permission: str | None, + dotfile_path: str | None, + owner_access_key: str | None, + domain: str | None, + group: str | None, +) -> None: """ Store dotfile to Backend.AI Manager. Dotfiles will be automatically loaded when creating kernels. @@ -102,7 +109,7 @@ def create(path, permission, dotfile_path, owner_access_key, domain, group) -> N "(If group name is provided, domain name must be specified with option -d)" ), ) -def get(path, owner_access_key, domain, group) -> None: +def get(path: str, owner_access_key: str | None, domain: str | None, group: str | None) -> None: """ Print dotfile content. 
""" @@ -139,7 +146,7 @@ def get(path, owner_access_key, domain, group) -> None: "(If group name is provided, domain name must be specified with option -d)" ), ) -def list(owner_access_key, domain, group) -> None: +def list(owner_access_key: str | None, domain: str | None, group: str | None) -> None: """ List available user/domain/group dotfiles. """ @@ -203,7 +210,14 @@ def list(owner_access_key, domain, group) -> None: "(If group name is provided, domain name must be specified with option -d)" ), ) -def update(path, permission, dotfile_path, owner_access_key, domain, group) -> None: +def update( + path: str, + permission: str | None, + dotfile_path: str | None, + owner_access_key: str | None, + domain: str | None, + group: str | None, +) -> None: """ Update dotfile stored in Backend.AI Manager. """ @@ -222,7 +236,7 @@ def update(path, permission, dotfile_path, owner_access_key, domain, group) -> N dotfile_ = session.Dotfile( path, owner_access_key=owner_access_key, domain=domain, group=group ) - dotfile_.update(body, permission) + _ = dotfile_.update(body, permission) print_info(f"Dotfile {dotfile_.path} updated") except Exception as e: print_error(e) @@ -252,7 +266,9 @@ def update(path, permission, dotfile_path, owner_access_key, domain, group) -> N "(If group name is provided, domain name must be specified with option -d)" ), ) -def delete(path, force, owner_access_key, domain, group) -> None: +def delete( + path: str, force: bool, owner_access_key: str | None, domain: str | None, group: str | None +) -> None: """ Delete dotfile from Backend.AI Manager. """ @@ -267,7 +283,7 @@ def delete(path, force, owner_access_key, domain, group) -> None: print_info("Aborting.") exit() try: - dotfile_.delete() + _ = dotfile_.delete() print_info(f"Dotfile {dotfile_.path} deleted") except Exception as e: print_error(e) diff --git a/src/ai/backend/client/cli/image.py b/src/ai/backend/client/cli/image.py index b52d931e1cb..56d622b28e9 100644 --- a/src/ai/backend/client/cli/image.py +++ b/src/ai/backend/client/cli/image.py @@ -26,7 +26,7 @@ def get_image_id( architecture: str | None = None, ) -> str: try: - session.Image.get_by_id(name_or_id, fields=[image_fields["id"]]) + _ = session.Image.get_by_id(name_or_id, fields=[image_fields["id"]]) return name_or_id except Exception: image = session.Image.get(name_or_id, architecture, fields=[image_fields["id"]]) @@ -66,7 +66,7 @@ def list(ctx: CLIContext, customized: bool) -> None: @image.command() @click.argument("reference_or_id", type=str) @click.option("--arch", type=str, default=None, help="Set an explicit architecture.") -def forget(reference_or_id, arch) -> None: +def forget(reference_or_id: str, arch: str | None) -> None: """Mark image as deleted from server. This command will only work for image customized by user unless callee has superadmin privileges. 
@@ -112,7 +112,7 @@ def purge(reference_or_id: str, arch: str, remove_from_registry: bool) -> None: ) sys.exit(ExitCode.FAILURE) try: - session.Image.purge_image_by_id(image_id, remove_from_registry=remove_from_registry) + _ = session.Image.purge_image_by_id(image_id, remove_from_registry=remove_from_registry) except Exception as e: print_error(e) sys.exit(ExitCode.FAILURE) diff --git a/src/ai/backend/client/cli/logs.py b/src/ai/backend/client/cli/logs.py index 3f8836ddc84..83cd0f604e9 100644 --- a/src/ai/backend/client/cli/logs.py +++ b/src/ai/backend/client/cli/logs.py @@ -12,7 +12,7 @@ @main.command() @click.argument("task_id", metavar="TASKID") -def task_logs(task_id) -> None: +def task_logs(task_id: str) -> None: """ Shows the output logs of a batch task. diff --git a/src/ai/backend/client/cli/model.py b/src/ai/backend/client/cli/model.py index 872cc31f50d..70e55c9b280 100644 --- a/src/ai/backend/client/cli/model.py +++ b/src/ai/backend/client/cli/model.py @@ -32,7 +32,9 @@ def model() -> None: @click.option("--order", default=None, help="Set the query ordering expression.") @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list(ctx: CLIContext, filter_, order, offset, limit) -> None: +def list( + ctx: CLIContext, filter_: str | None, order: str | None, offset: int, limit: int | None +) -> None: """ List the models. """ @@ -58,7 +60,7 @@ def list(ctx: CLIContext, filter_, order, offset, limit) -> None: @model.command() @pass_ctx_obj @click.argument("model_name", metavar="MODEL", type=str) -def info(ctx: CLIContext, model_name) -> None: +def info(ctx: CLIContext, model_name: str) -> None: """ Display the detail of a model with its backing storage vfolder. @@ -141,7 +143,16 @@ def info(ctx: CLIContext, model_name) -> None: is_flag=True, help="Allows the virtual folder to be cloned by users.", ) -def create(ctx: CLIContext, name, host, group, host_path, permission, quota, cloneable) -> None: +def create( + ctx: CLIContext, + name: str, + host: str, + group: str | None, + host_path: bool, + permission: str, + quota: str, + cloneable: bool, +) -> None: """ Create a new model with the given configuration. @@ -180,7 +191,7 @@ def create(ctx: CLIContext, name, host, group, host_path, permission, quota, clo @model.command() @pass_ctx_obj @click.argument("model_name", metavar="MODEL", type=str) -def rm(ctx: CLIContext, model_name) -> None: +def rm(ctx: CLIContext, model_name: str) -> None: """ Remove the given model. @@ -191,7 +202,7 @@ def rm(ctx: CLIContext, model_name) -> None: with Session() as session: try: serving = session.Model(model_name) - serving.delete() + _ = serving.delete() print_done("Model deleted.") except Exception as e: ctx.output.print_error(e) @@ -236,12 +247,12 @@ def rm(ctx: CLIContext, model_name) -> None: ) def upload( ctx: CLIContext, - model_name, - filenames, - model_version, - base_dir, - chunk_size, - override_storage_proxy, + model_name: str, + filenames: tuple[Path, ...], + model_version: str, + base_dir: Path | None, + chunk_size: int, + override_storage_proxy: dict[str, str] | None, ) -> None: """ Upload a file to the model as the given version. 
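The override_storage_proxy parameter above is annotated dict[str, str] | None, which implies the option converts its raw host-mapping input into a dict before the command body runs. The converter itself is not shown in this patch, so the following click callback is purely hypothetical, sketching one way such a value could acquire that type:

    import click


    def parse_host_overrides(
        ctx: click.Context, param: click.Parameter, value: str | None
    ) -> dict[str, str] | None:
        # Hypothetical converter: "old1=new1,old2=new2" -> {"old1": "new1", ...}.
        # The patch only shows the resulting annotation, not the real converter.
        if value is None:
            return None
        overrides: dict[str, str] = {}
        for pair in value.split(","):
            src, _, dst = pair.partition("=")
            overrides[src.strip()] = dst.strip()
        return overrides


    @click.command()
    @click.option("--override-storage-proxy", callback=parse_host_overrides, default=None)
    def upload(override_storage_proxy: dict[str, str] | None) -> None:
        click.echo(repr(override_storage_proxy))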
@@ -253,7 +264,7 @@ def upload( """ with Session() as session: try: - session.VFolder(model_name).upload( + _ = session.VFolder(model_name).upload( filenames, dst_dir=Path("versions", model_version), basedir=base_dir, @@ -313,13 +324,13 @@ def upload( ) def download( ctx: CLIContext, - model_name, - filenames, - model_version, - base_dir, - chunk_size, - override_storage_proxy, - max_retries, + model_name: str, + filenames: tuple[Path, ...], + model_version: str, + base_dir: Path | None, + chunk_size: int, + override_storage_proxy: dict[str, str] | None, + max_retries: int, ) -> None: """ Download a file from the model storage. @@ -332,7 +343,7 @@ def download( """ with Session() as session: try: - session.VFolder(model_name).download( + _ = session.VFolder(model_name).download( filenames, dst_dir=Path("versions", model_version), basedir=base_dir, diff --git a/src/ai/backend/client/cli/network.py b/src/ai/backend/client/cli/network.py index b206cc3f807..128a59a8963 100644 --- a/src/ai/backend/client/cli/network.py +++ b/src/ai/backend/client/cli/network.py @@ -32,7 +32,7 @@ def network() -> None: @click.argument("project", type=str, metavar="PROJECT_ID_OR_NAME") @click.argument("name", type=str, metavar="NAME") @click.option("-d", "--driver", default=None, help="Set the network driver.") -def create(ctx: CLIContext, project, name, driver) -> None: +def create(ctx: CLIContext, project: str, name: str, driver: str | None) -> None: """Create a new network interface.""" with Session() as session: @@ -73,7 +73,14 @@ def create(ctx: CLIContext, project, name, driver) -> None: @click.option("--order", default=None, help="Set the query ordering expression.") @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list(ctx: CLIContext, format, filter_, order, offset, limit) -> None: +def list( + ctx: CLIContext, + format: str | None, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, +) -> None: """List all available network interfaces.""" if format: @@ -112,7 +119,7 @@ def list(ctx: CLIContext, format, filter_, order, offset, limit) -> None: default=None, help="Display only specified fields. 
When specifying multiple fields separate them with comma (,).", ) -def get(ctx: CLIContext, network, format) -> None: +def get(ctx: CLIContext, network: str, format: str | None) -> None: fields: Iterable[Any] if format: try: @@ -144,7 +151,7 @@ def get(ctx: CLIContext, network, format) -> None: @network.command() @pass_ctx_obj @click.argument("network", type=str, metavar="NETWORK_ID_OR_NAME") -def delete(ctx: CLIContext, network) -> None: +def delete(ctx: CLIContext, network: str) -> None: with Session() as session: try: network_info = session.Network(uuid.UUID(network)).get(fields=[network_fields["id"]]) @@ -163,7 +170,7 @@ def delete(ctx: CLIContext, network) -> None: network_info = networks.items[0] try: - session.Network(uuid.UUID(network_info["row_id"])).delete() + _ = session.Network(uuid.UUID(network_info["row_id"])).delete() print_done(f"Network {network} has been deleted.") except BackendAPIError as e: ctx.output.print_fail(f"Failed to delete network {network}:") diff --git a/src/ai/backend/client/cli/pretty.py b/src/ai/backend/client/cli/pretty.py index 7fea0e4c807..0cc3e164ba6 100644 --- a/src/ai/backend/client/cli/pretty.py +++ b/src/ai/backend/client/cli/pretty.py @@ -9,7 +9,7 @@ import traceback from collections.abc import Iterator, Sequence from types import TracebackType -from typing import Optional, Self +from typing import Any, Optional, Self, TextIO from click import echo, style from tqdm import tqdm @@ -79,7 +79,9 @@ def format_pretty(msg: str, status: PrintStatus = PrintStatus.NONE, colored: boo format_warn = functools.partial(format_pretty, status=PrintStatus.WARNING) -def print_pretty(msg, *, status=PrintStatus.NONE, file=None) -> None: +def print_pretty( + msg: str, *, status: PrintStatus = PrintStatus.NONE, file: TextIO | None = None +) -> None: if file is None: file = sys.stderr if status == PrintStatus.NONE: @@ -182,7 +184,7 @@ def format_error(exc: Exception) -> Iterator[str]: yield ("*** Traceback ***\n" + "".join(traceback.format_tb(exc.__traceback__)).strip()) -def print_error(exc: Exception, *, file=None) -> None: +def print_error(exc: Exception, *, file: TextIO | None = None) -> None: if file is None: file = sys.stderr indicator = style("\u2718", fg="bright_red", reset=False) @@ -196,7 +198,14 @@ def print_error(exc: Exception, *, file=None) -> None: file.flush() -def show_warning(message, category, filename, lineno, file=None, line=None) -> None: +def show_warning( + message: str, + category: type[Warning], + filename: str, + lineno: int, + file: TextIO | None = None, + line: str | None = None, +) -> None: echo( "{}: {}".format( style(str(category.__name__), fg="yellow", bold=True), @@ -224,19 +233,19 @@ class ProgressBarWithSpinner(tqdm): @staticmethod def alt_format_meter( - n, - total, - elapsed, - ncols=None, - prefix="", - ascii=False, - unit="it", - unit_scale=False, - rate=None, - bar_format=None, - postfix=None, - *args, - **kwargs, + n: int | float, + total: int | float | None, + elapsed: float, + ncols: int | None = None, + prefix: str = "", + ascii: bool = False, + unit: str = "it", + unit_scale: bool = False, + rate: float | None = None, + bar_format: str | None = None, + postfix: str | None = None, + *args: Any, + **kwargs: Any, ) -> str: # Return the prefix string only. 
return str(prefix) + str(postfix) diff --git a/src/ai/backend/client/cli/proxy.py b/src/ai/backend/client/cli/proxy.py index 1e7b3b30e11..5bf6d9ce86a 100644 --- a/src/ai/backend/client/cli/proxy.py +++ b/src/ai/backend/client/cli/proxy.py @@ -116,7 +116,7 @@ def _translate_headers(upstream_request: Request, client_request: Request) -> No upstream_request.headers["Host"] = f"{api_endpoint.host}:{api_endpoint.port}" -async def web_handler(request) -> web.StreamResponse: +async def web_handler(request: web.Request) -> web.StreamResponse: path = re.sub(r"^/?v(\d+)/", "/", request.path) try: # We treat all requests and responses as streaming universally @@ -167,7 +167,7 @@ async def web_handler(request) -> web.StreamResponse: ) -async def websocket_handler(request) -> web.WebSocketResponse | web.Response: +async def websocket_handler(request: web.Request) -> web.WebSocketResponse | web.Response: path = re.sub(r"^/?v(\d+)/", "/", request.path) try: api_request = Request( @@ -235,7 +235,7 @@ def create_proxy_app() -> web.Application: help="The TCP port to accept non-encrypted non-authorized API requests.", ) @click.pass_context -def proxy(ctx, bind, port) -> None: +def proxy(ctx: click.Context, bind: str, port: int) -> None: """ Run a non-encrypted non-authorized API proxy server. Use this only for development and testing! diff --git a/src/ai/backend/client/cli/server_log.py b/src/ai/backend/client/cli/server_log.py index 3d0955151e6..793575ce4cb 100644 --- a/src/ai/backend/client/cli/server_log.py +++ b/src/ai/backend/client/cli/server_log.py @@ -22,7 +22,7 @@ def server_logs() -> None: "-l", "--page-size", type=int, default=20, help="Number of logs to fetch (from latest log)" ) @click.option("-n", "--page-number", type=int, default=1, help="Page number to fetch.") -def list(mark_read, page_size, page_number) -> None: +def list(mark_read: bool, page_size: int, page_number: int) -> None: """Fetch server (error) logs.""" with Session() as session: try: diff --git a/src/ai/backend/client/cli/service.py b/src/ai/backend/client/cli/service.py index 3b24f8d1c55..99cedc760d6 100644 --- a/src/ai/backend/client/cli/service.py +++ b/src/ai/backend/client/cli/service.py @@ -47,7 +47,7 @@ def get_service_id(session: Session, name_or_id: str) -> UUID: try: - session.Service(name_or_id).info() + _ = session.Service(name_or_id).info() except (ValueError, BackendError): services = session.Service.list(name=name_or_id) try: @@ -601,7 +601,7 @@ def rm(ctx: CLIContext, service_name_or_id: str) -> None: with Session() as session: try: service_id = get_service_id(session, service_name_or_id) - session.Service(service_id).delete() + _ = session.Service(service_id).delete() print_done("Removed.") except Exception as e: ctx.output.print_error(e) @@ -621,7 +621,7 @@ def sync(ctx: CLIContext, service_name_or_id: str) -> None: with Session() as session: try: service_id = get_service_id(session, service_name_or_id) - session.Service(service_id).sync() + _ = session.Service(service_id).sync() print_done("Done.") except Exception as e: ctx.output.print_error(e) @@ -647,7 +647,7 @@ def scale( with Session() as session: try: service_id = get_service_id(session, service_name_or_id) - session.Service(service_id).scale(target_count) + _ = session.Service(service_id).scale(target_count) print_done("Triggered scaling.") except Exception as e: ctx.output.print_error(e) @@ -728,7 +728,7 @@ def update_traffic_ratio( with Session() as session: try: service_id = get_service_id(session, service_name_or_id) - 
session.Service(service_id).update_traffic_ratio(route_id, ratio) + _ = session.Service(service_id).update_traffic_ratio(route_id, ratio) print_done("Done.") except Exception as e: ctx.output.print_error(e) @@ -752,7 +752,7 @@ def downscale_route(ctx: CLIContext, service_name_or_id: str, route_id: UUID) -> with Session() as session: try: service_id = get_service_id(session, service_name_or_id) - session.Service(service_id).downscale_single_route(route_id) + _ = session.Service(service_id).downscale_single_route(route_id) print_done("Done.") except Exception as e: ctx.output.print_error(e) diff --git a/src/ai/backend/client/cli/service_auto_scaling_rule.py b/src/ai/backend/client/cli/service_auto_scaling_rule.py index 764918be193..78e015bc16c 100644 --- a/src/ai/backend/client/cli/service_auto_scaling_rule.py +++ b/src/ai/backend/client/cli/service_auto_scaling_rule.py @@ -236,8 +236,8 @@ def update( try: rule_instance = session.ServiceAutoScalingRule(rule) - rule_instance.get() - rule_instance.update( + _ = rule_instance.get() + _ = rule_instance.update( metric_source=metric_source, metric_name=metric_name, threshold=_threshold, @@ -261,8 +261,8 @@ def delete(ctx: CLIContext, rule: UUID) -> None: with Session() as session: rule_instance = session.ServiceAutoScalingRule(rule) try: - rule_instance.get(fields=[service_auto_scaling_rule_fields["id"]]) - rule_instance.delete() + _ = rule_instance.get(fields=[service_auto_scaling_rule_fields["id"]]) + _ = rule_instance.delete() print_done(f"Autosscaling rule {rule_instance.rule_id} has been deleted.") except BackendAPIError as e: ctx.output.print_fail(f"Failed to delete rule {rule_instance.rule_id}:") diff --git a/src/ai/backend/client/cli/session/app.py b/src/ai/backend/client/cli/session/app.py index 90ce1c71cfb..1380df8c072 100644 --- a/src/ai/backend/client/cli/session/app.py +++ b/src/ai/backend/client/cli/session/app.py @@ -3,7 +3,7 @@ import shlex import sys from collections.abc import MutableMapping, Sequence -from typing import Optional +from typing import Any, Optional import aiohttp import click @@ -291,7 +291,7 @@ async def __aexit__(self, *exc_info) -> None: metavar='"ENVNAME=envvalue"', help="Add additional environment variable when starting service.", ) -def app(session_name, app, bind, arg, env) -> None: +def app(session_name: str, app: str, bind: str, arg: tuple[str, ...], env: tuple[str, ...]) -> None: """ Run a local proxy to a service provided by Backend.AI compute sessions. @@ -329,7 +329,7 @@ def app(session_name, app, bind, arg, env) -> None: @click.argument("session_name", type=str, metavar="NAME", nargs=1) @click.argument("app_name", type=str, metavar="APP", nargs=-1) @click.option("-l", "--list-names", is_flag=True, help="Just print all available services.") -def apps(session_name, app_name, list_names) -> None: +def apps(session_name: str, app_name: tuple[str, ...], list_names: bool) -> None: """ List available additional arguments and environment variables when starting service. 
@@ -340,7 +340,7 @@ def apps(session_name, app_name, list_names) -> None: """ async def print_arguments() -> None: - apps = [] + apps: list[dict[str, Any]] = [] async with AsyncSession() as api_session: compute_session = api_session.ComputeSession(session_name) apps = await compute_session.stream_app_info() diff --git a/src/ai/backend/client/cli/session/execute.py b/src/ai/backend/client/cli/session/execute.py index 7ad802758cc..e70f92ccb07 100644 --- a/src/ai/backend/client/cli/session/execute.py +++ b/src/ai/backend/client/cli/session/execute.py @@ -10,9 +10,9 @@ import sys import traceback import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from decimal import Decimal -from typing import Optional +from typing import TYPE_CHECKING, Any, Optional, TextIO import aiohttp import click @@ -37,6 +37,9 @@ from ai.backend.client.exceptions import BackendError from ai.backend.client.output.fields import network_fields from ai.backend.client.session import AsyncSession + +if TYPE_CHECKING: + from ai.backend.client.func.session import ComputeSession from ai.backend.common.arch import DEFAULT_IMAGE_ARCH from ai.backend.common.types import ClusterMode, MountExpression @@ -48,15 +51,15 @@ async def exec_loop( - stdout, - stderr, - compute_session, - mode, - code, + stdout: TextIO, + stderr: TextIO, + compute_session: ComputeSession, + mode: str, + code: str, *, - opts=None, - vprint_done=print_done, - is_multi=False, + opts: dict[str, Any] | None = None, + vprint_done: Callable[[str], None] = print_done, + is_multi: bool = False, ) -> None: """ Fully streamed asynchronous version of the execute loop. @@ -114,7 +117,14 @@ async def exec_loop( def exec_loop_sync( - stdout, stderr, compute_session, mode, code, *, opts=None, vprint_done=print_done + stdout: TextIO, + stderr: TextIO, + compute_session: ComputeSession, + mode: str, + code: str, + *, + opts: dict[str, Any] | None = None, + vprint_done: Callable[[str], None] = print_done, ) -> None: """ Old synchronous polling version of the execute loop. @@ -143,17 +153,17 @@ def exec_loop_sync( print("--- end of generated files ---", file=stdout) if result["status"] == "clean-finished": exitCode = result.get("exitCode") - vprint_done(f"Clean finished. (exit code = {exitCode}", file=stdout) + vprint_done(f"Clean finished. (exit code = {exitCode})") mode = "continue" code = "" elif result["status"] == "build-finished": exitCode = result.get("exitCode") - vprint_done(f"Build finished. (exit code = {exitCode})", file=stdout) + vprint_done(f"Build finished. (exit code = {exitCode})") mode = "continue" code = "" elif result["status"] == "finished": exitCode = result.get("exitCode") - vprint_done(f"Execution finished. (exit code = {exitCode})", file=stdout) + vprint_done(f"Execution finished. (exit code = {exitCode})") break elif result["status"] == "waiting-input": mode = "input" @@ -166,16 +176,21 @@ def exec_loop_sync( code = "" -async def exec_terminal(compute_session, *, vprint_wait=print_wait, vprint_done=print_done) -> None: +async def exec_terminal( + compute_session: ComputeSession, + *, + vprint_wait: Callable[[str], None] = print_wait, + vprint_done: Callable[[str], None] = print_done, +) -> None: # async with compute_session.stream_pty() as stream: ... 
raise NotImplementedError -def _noop(*args, **kwargs) -> None: +def _noop(*args: Any, **kwargs: Any) -> None: pass -def format_stats(stats) -> str: +def format_stats(stats: dict[str, Any]) -> str: formatted = [] version = stats.pop("version", 1) stats.pop("status") @@ -394,44 +409,44 @@ def prepare_mount_arg( ) @click_start_option() def run( - image, - files, - name, # click_start_option - type, # click_start_option + image: str, + files: tuple[str, ...], + name: str | None, # click_start_option + type: str, # click_start_option priority: int | None, # click_start_option - starts_at, # click_start_option - enqueue_only, # click_start_option - max_wait, # click_start_option - no_reuse, # click_start_option - callback_url, # click_start_option - code, - terminal, # query-mode options - clean, - build, - exec, - basedir, # batch-mode options - env, # click_start_option - bootstrap_script, - rm, - stats, - tag, # click_start_option - quiet, # extra options - env_range, - build_range, - exec_range, - max_parallel, # experiment support - mount, # click_start_option - scaling_group, # click_start_option - resources, # click_start_option - cluster_size, # click_start_option + starts_at: str | None, # click_start_option + enqueue_only: bool, # click_start_option + max_wait: int, # click_start_option + no_reuse: bool, # click_start_option + callback_url: str | None, # click_start_option + code: str | None, + terminal: bool, # query-mode options + clean: str | None, + build: str | None, + exec: str | None, + basedir: str | None, # batch-mode options + env: tuple[str, ...], # click_start_option + bootstrap_script: TextIO | None, + rm: bool, + stats: bool, + tag: str | None, # click_start_option + quiet: bool, # extra options + env_range: tuple[str, ...], + build_range: tuple[str, ...], + exec_range: tuple[str, ...], + max_parallel: int, # experiment support + mount: tuple[str, ...], # click_start_option + scaling_group: str | None, # click_start_option + resources: tuple[str, ...], # click_start_option + cluster_size: int, # click_start_option cluster_mode: ClusterMode, - resource_opts, # click_start_option - architecture, - domain, # click_start_option - group, # click_start_option - network, # click_start_option - preopen, - assign_agent, # resource grouping + resource_opts: tuple[str, ...], # click_start_option + architecture: str, + domain: str | None, # click_start_option + group: str | None, # click_start_option + network: str | None, # click_start_option + preopen: list[str] | None, + assign_agent: list[str] | None, # resource grouping ) -> None: """ Run the given code snippet or files in a session. 
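The ComputeSession import added to execute.py lives under TYPE_CHECKING, so it exists for the type checker only. For the annotated signatures to keep working at runtime, the module must defer annotation evaluation; the patch context does not show it, but a from __future__ import annotations at the top of the file (or quoted names) is what this pattern assumes. A self-contained sketch:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only while type checking: avoids a runtime import cycle
        # (or a heavy import) while still typing the parameter below.
        from ai.backend.client.func.session import ComputeSession


    async def exec_loop(compute_session: ComputeSession, mode: str, code: str) -> None:
        # With lazy annotations, ComputeSession is never looked up at runtime.
        ...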
@@ -472,9 +487,9 @@ def run( if exec_range is None: exec_range = [] - env_ranges = {v: r for v, r in env_range} - build_ranges = {v: r for v, r in build_range} - exec_ranges = {v: r for v, r in exec_range} + env_ranges: dict[str, Any] = {v: r for v, r in env_range} # type: ignore[has-type] + build_ranges: dict[str, Any] = {v: r for v, r in build_range} # type: ignore[has-type] + exec_ranges: dict[str, Any] = {v: r for v, r in exec_range} # type: ignore[has-type] env_var_maps = [ dict(zip(env_ranges.keys(), values, strict=True)) @@ -528,7 +543,15 @@ def run( pretty_env = " ".join(f"{item[0]}={item[1]}" for item in case[0]) print(f"env = {pretty_env!r}, build = {case[1]!r}, exec = {case[2]!r}") - def _run_legacy(session, idx, name, envs, clean_cmd, build_cmd, exec_cmd) -> None: + def _run_legacy( + session: AsyncSession, + idx: int, + name: str, + envs: Mapping[str, str], + clean_cmd: str | None, + build_cmd: str | None, + exec_cmd: str | None, + ) -> None: try: compute_session = session.ComputeSession.get_or_create( image, @@ -628,7 +651,14 @@ def _run_legacy(session, idx, name, envs, clean_cmd, build_cmd, exec_cmd) -> Non print(f"[{idx}] Statistics is not available.") async def _run( - session, idx, name, envs, clean_cmd, build_cmd, exec_cmd, is_multi=False + session: AsyncSession, + idx: int, + name: str, + envs: Mapping[str, str], + clean_cmd: str | None, + build_cmd: str | None, + exec_cmd: str | None, + is_multi: bool = False, ) -> None: try: if network: @@ -719,7 +749,7 @@ async def _run( try: - def indexed_vprint_done(msg) -> None: + def indexed_vprint_done(msg: str) -> None: vprint_done(f"[{idx}] " + msg) if files: diff --git a/src/ai/backend/client/cli/session/lifecycle.py b/src/ai/backend/client/cli/session/lifecycle.py index 44b9c42756e..b260983709c 100644 --- a/src/ai/backend/client/cli/session/lifecycle.py +++ b/src/ai/backend/client/cli/session/lifecycle.py @@ -10,7 +10,7 @@ from datetime import UTC, datetime, timedelta from graphlib import TopologicalSorter from pathlib import Path -from typing import IO, Literal, Optional, cast +from typing import IO, Any, Literal, Optional, cast from uuid import UUID import click @@ -622,7 +622,13 @@ def _destroy_cmd(docs: Optional[str] = None) -> Callable[..., None]: @click.option( "-r", "--recursive", is_flag=True, help="Cancel all the dependant sessions recursively" ) - def destroy(session_names, forced, owner, stats, recursive) -> None: + def destroy( + session_names: tuple[str, ...], + forced: bool, + owner: str | None, + stats: bool, + recursive: bool, + ) -> None: """ Terminate and destroy the given session. @@ -686,7 +692,7 @@ def _restart_cmd(docs: Optional[str] = None) -> Callable[..., None]: metavar="ACCESS_KEY", help="Specify the owner of the target session explicitly.", ) - def restart(session_refs, owner) -> None: + def restart(session_refs: tuple[str, ...], owner: str | None) -> None: """ Restart the compute session. 
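Several execute.py helpers now declare vprint_done: Callable[[str], None], and the matching cleanup drops the stray file=stdout keyword from call sites, since a Callable[[str], None] accepts exactly one positional string. The indexed_vprint_done closure above shows the intended shape; a sketch of the pattern under that assumption:

    from collections.abc import Callable


    def print_done(msg: str) -> None:
        print(f"done: {msg}")


    def run_steps(steps: list[str], *, vprint_done: Callable[[str], None] = print_done) -> None:
        for step in steps:
            # The declared type allows a single str argument only;
            # extra keywords like file=... would no longer type-check.
            vprint_done(f"finished {step}")


    def run_indexed(idx: int, steps: list[str]) -> None:
        def indexed_vprint_done(msg: str) -> None:
            print_done(f"[{idx}] {msg}")

        run_steps(steps, vprint_done=indexed_vprint_done)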
@@ -702,7 +708,7 @@ def restart(session_refs, owner) -> None: for session_ref in session_refs: try: compute_session = session.ComputeSession(session_ref, owner) - compute_session.restart() + _ = compute_session.restart() except BackendAPIError as e: print_error(e) if e.status == 404: @@ -732,7 +738,7 @@ def restart(session_refs, owner) -> None: @session.command() @click.argument("session_id", metavar="SESSID") @click.argument("files", type=click.Path(exists=True), nargs=-1) -def upload(session_id, files) -> None: +def upload(session_id: str, files: tuple[str, ...]) -> None: """ Upload the files to a compute session's home directory. If the target directory is in a storage folder mount, the operation is @@ -753,7 +759,7 @@ def upload(session_id, files) -> None: try: print_wait("Uploading files...") kernel = session.ComputeSession(session_id) - kernel.upload(files, show_progress=True) + _ = kernel.upload(files, show_progress=True) print_done("Uploaded.") except Exception as e: print_error(e) @@ -764,7 +770,7 @@ def upload(session_id, files) -> None: @click.argument("session_id", metavar="SESSID") @click.argument("files", nargs=-1) @click.option("--dest", type=Path, default=".", help="Destination path to store downloaded file(s)") -def download(session_id, files, dest) -> None: +def download(session_id: str, files: tuple[str, ...], dest: Path) -> None: """ Download files from a compute session's home directory. If the source path is in a storage folder mount, the operation is @@ -785,7 +791,7 @@ def download(session_id, files, dest) -> None: try: print_wait(f"Downloading file(s) from {session_id}...") kernel = session.ComputeSession(session_id) - kernel.download(files, dest, show_progress=True) + _ = kernel.download(files, dest, show_progress=True) print_done(f"Downloaded to {dest.resolve()}.") except Exception as e: print_error(e) @@ -795,7 +801,7 @@ def download(session_id, files, dest) -> None: @session.command() @click.argument("session_id", metavar="SESSID") @click.argument("path", metavar="PATH", nargs=1, default="/home/work") -def ls(session_id, path) -> None: +def ls(session_id: str, path: str) -> None: """ List files in a path of a running compute session. 
@@ -1004,7 +1010,7 @@ def convert_to_image(session_id: str, image_name: str) -> None: print_error(e) sys.exit(ExitCode.FAILURE) - async def export_tracker(bgtask_id) -> None: + async def export_tracker(bgtask_id: str) -> None: async with AsyncSession() as session: completion_msg_func = lambda: print_done("Session export process completed.") try: @@ -1246,7 +1252,7 @@ def _events_cmd(docs: Optional[str] = None) -> Callable[..., None]: ) def events( session_id_or_name: str, - owner_access_key: str, + owner_access_key: str | None, scope: str, quiet: bool, wait: Optional[str] = None, @@ -1263,7 +1269,7 @@ def events( - Any other event name waits for that specific event (exit code 0) """ - def print_event(ev) -> None: + def print_event(ev: Any) -> None: click.echo( click.style(ev.event, fg="cyan", bold=True) + " " @@ -1373,7 +1379,11 @@ def _watch_cmd(docs: Optional[str] = None) -> Callable[..., None]: help="Set the output style of the command results.", ) def watch( - session_name_or_id: str, owner_access_key: str, scope: str, max_wait: int, output: str + session_name_or_id: tuple[str, ...], + owner_access_key: str | None, + scope: str, + max_wait: int, + output: str, ) -> None: """ Monitor the lifecycle events of a compute session diff --git a/src/ai/backend/client/cli/session_template.py b/src/ai/backend/client/cli/session_template.py index ac3bc80a610..ba129d53dff 100644 --- a/src/ai/backend/client/cli/session_template.py +++ b/src/ai/backend/client/cli/session_template.py @@ -52,7 +52,9 @@ def session_template() -> None: metavar="ACCESS_KEY", help="Set the owner of the target session explicitly.", ) -def create(template_path, domain, group, owner_access_key) -> None: +def create( + template_path: str | None, domain: str | None, group: str | None, owner_access_key: str | None +) -> None: """ Store task template to Backend.AI Manager and return template ID. Template can be used when creating new session. @@ -94,7 +96,7 @@ def create(template_path, domain, group, owner_access_key) -> None: metavar="ACCESS_KEY", help="Specify the owner of the target session explicitly.", ) -def get(template_id, template_format, owner_access_key) -> None: +def get(template_id: str, template_format: str, owner_access_key: str | None) -> None: """ Print task template associated with given template ID """ @@ -115,7 +117,7 @@ def get(template_id, template_format, owner_access_key) -> None: is_flag=True, help="List all virtual folders (superadmin privilege is required).", ) -def list(list_all) -> None: +def list(list_all: bool) -> None: """ List all available task templates by user. """ @@ -160,7 +162,7 @@ def list(list_all) -> None: metavar="ACCESS_KEY", help="Specify the owner of the target session explicitly.", ) -def update(template_id, template_path, owner_access_key) -> None: +def update(template_id: str, template_path: str | None, owner_access_key: str | None) -> None: """ Update task template stored in Backend.AI Manager. 
""" @@ -175,7 +177,7 @@ def update(template_id, template_path, owner_access_key) -> None: with Session() as session: try: template = session.SessionTemplate(template_id, owner_access_key=owner_access_key) - template.put(body) + _ = template.put(body) print_info(f"Task template {template.template_id} updated") except Exception as e: print_error(e) @@ -198,7 +200,7 @@ def update(template_id, template_path, owner_access_key) -> None: metavar="ACCESS_KEY", help="Specify the owner of the target session explicitly.", ) -def delete(template_id, force, owner_access_key) -> None: +def delete(template_id: str, force: bool, owner_access_key: str | None) -> None: """ Delete task template from Backend.AI Manager. """ @@ -211,7 +213,7 @@ def delete(template_id, force, owner_access_key) -> None: print_info("Aborting.") exit() try: - template.delete() + _ = template.delete() print_info(f"Task template {template.template_id} deleted") except Exception as e: print_error(e) diff --git a/src/ai/backend/client/cli/vfolder.py b/src/ai/backend/client/cli/vfolder.py index d75cea05a23..7cd840cc9bf 100644 --- a/src/ai/backend/client/cli/vfolder.py +++ b/src/ai/backend/client/cli/vfolder.py @@ -127,7 +127,16 @@ def list_allowed_types() -> None: is_flag=True, help="Allows the virtual folder to be cloned by users.", ) -def create(name, host, group, host_path, unmanaged_path, usage_mode, permission, cloneable) -> None: +def create( + name: str, + host: str, + group: str | None, + host_path: bool, + unmanaged_path: str | None, + usage_mode: str, + permission: str, + cloneable: bool, +) -> None: """Create a new virtual folder. \b @@ -153,7 +162,7 @@ def create(name, host, group, host_path, unmanaged_path, usage_mode, permission, @vfolder.command() @click.argument("name", type=str) -def delete(name) -> None: +def delete(name: str) -> None: """Delete the given virtual folder. The virtual folder will be under `delete-pending` status, which means trash-bin. This operation can be retracted by calling `restore()`. @@ -163,7 +172,7 @@ def delete(name) -> None: """ with Session() as session: try: - session.VFolder(name).delete() + _ = session.VFolder(name).delete() print_done("Deleted.") except Exception as e: print_error(e) @@ -172,14 +181,14 @@ def delete(name) -> None: @vfolder.command() @click.argument("name", type=str) -def purge(name) -> None: +def purge(name: str) -> None: """Purge the given virtual folder. This operation is irreversible! NAME: Name of a virtual folder. """ with Session() as session: try: - session.VFolder(name).purge() + _ = session.VFolder(name).purge() print_done("Purged.") except Exception as e: print_error(e) @@ -188,7 +197,7 @@ def purge(name) -> None: @vfolder.command() @click.argument("name", type=str) -def delete_trash(name) -> None: +def delete_trash(name: str) -> None: """Delete the given virtual folder's real data. The virtual folder should be under `delete-pending` status, which means trash-bin. This operation is irreversible! @@ -196,7 +205,7 @@ def delete_trash(name) -> None: """ with Session() as session: try: - session.VFolder(name).delete_trash() + _ = session.VFolder(name).delete_trash() print_done("Delete completed.") except Exception as e: print_error(e) @@ -205,14 +214,14 @@ def delete_trash(name) -> None: @vfolder.command() @click.argument("name", type=str) -def recover(name) -> None: +def recover(name: str) -> None: """Restore the given virtual folder from deleted status, Deprecated since 24.03.1; use `restore` NAME: Name of a virtual folder. 
""" with Session() as session: try: - session.VFolder(name).restore() + _ = session.VFolder(name).restore() print_done("Restored.") except Exception as e: print_error(e) @@ -221,14 +230,14 @@ def recover(name) -> None: @vfolder.command() @click.argument("name", type=str) -def restore(name) -> None: +def restore(name: str) -> None: """Restore the given virtual folder from deleted status, from trash bin. NAME: Name of a virtual folder. """ with Session() as session: try: - session.VFolder(name).restore() + _ = session.VFolder(name).restore() print_done("Restored.") except Exception as e: print_error(e) @@ -238,7 +247,7 @@ def restore(name) -> None: @vfolder.command() @click.argument("old_name", type=str) @click.argument("new_name", type=str) -def rename(old_name, new_name) -> None: +def rename(old_name: str, new_name: str) -> None: """Rename the given virtual folder. This operation is irreversible! You cannot change the vfolders that are shared by other users, and the new name must be unique among all your accessible vfolders @@ -250,7 +259,7 @@ def rename(old_name, new_name) -> None: """ with Session() as session: try: - session.VFolder(old_name).rename(new_name) + _ = session.VFolder(old_name).rename(new_name) print_done("Renamed.") except Exception as e: print_error(e) @@ -259,7 +268,7 @@ def rename(old_name, new_name) -> None: @vfolder.command() @click.argument("name", type=str) -def info(name) -> None: +def info(name: str) -> None: """Show the information of the given virtual folder. \b @@ -324,7 +333,14 @@ def info(name) -> None: " include the protocol part and the port number to replace." ), ) -def upload(name, filenames, base_dir, recursive, chunk_size, override_storage_proxy) -> None: +def upload( + name: str, + filenames: tuple[Path, ...], + base_dir: Path | None, + recursive: bool, + chunk_size: int, + override_storage_proxy: dict[str, str] | None, +) -> None: """ TUS Upload a file to the virtual folder from the current working directory. The files with the same names will be overwritten. @@ -335,7 +351,7 @@ def upload(name, filenames, base_dir, recursive, chunk_size, override_storage_pr """ with Session() as session: try: - session.VFolder(name).upload( + _ = session.VFolder(name).upload( filenames, basedir=base_dir, recursive=recursive, @@ -389,7 +405,14 @@ def upload(name, filenames, base_dir, recursive, chunk_size, override_storage_pr default=20, help="Maximum retry attempt when any failure occurs.", ) -def download(name, filenames, base_dir, chunk_size, override_storage_proxy, max_retries) -> None: +def download( + name: str, + filenames: tuple[Path, ...], + base_dir: Path | None, + chunk_size: int, + override_storage_proxy: dict[str, str] | None, + max_retries: int, +) -> None: """ Download a file from the virtual folder to the current working directory. The files with the same names will be overwritten. @@ -400,7 +423,7 @@ def download(name, filenames, base_dir, chunk_size, override_storage_proxy, max_ """ with Session() as session: try: - session.VFolder(name).download( + _ = session.VFolder(name).download( filenames, basedir=base_dir, chunk_size=chunk_size, @@ -418,7 +441,7 @@ def download(name, filenames, base_dir, chunk_size, override_storage_proxy, max_ @vfolder.command() @click.argument("name", type=str) @click.argument("filename", type=Path) -def request_download(name, filename) -> None: +def request_download(name: str, filename: Path) -> None: """ Request JWT-formatted download token for later use. 
@@ -437,7 +460,7 @@ def request_download(name, filename) -> None: @vfolder.command() @click.argument("filenames", nargs=-1) -def cp(filenames) -> None: +def cp(filenames: tuple[str, ...]) -> None: """An scp-like shortcut for download/upload commands. \b @@ -496,7 +519,7 @@ def mkdir( @click.argument("name", type=str) @click.argument("target_path", type=str) @click.argument("new_name", type=str) -def rename_file(name, target_path, new_name) -> None: +def rename_file(name: str, target_path: str, new_name: str) -> None: """ Rename a file or a directory in a virtual folder. @@ -507,7 +530,7 @@ def rename_file(name, target_path, new_name) -> None: """ with Session() as session: try: - session.VFolder(name).rename_file(target_path, new_name) + _ = session.VFolder(name).rename_file(target_path, new_name) print_done("Renamed.") except Exception as e: print_error(e) @@ -518,7 +541,7 @@ def rename_file(name, target_path, new_name) -> None: @click.argument("name", type=str) @click.argument("src", type=str) @click.argument("dst", type=str) -def mv(name, src, dst) -> None: +def mv(name: str, src: str, dst: str) -> None: """ Move a file or a directory within a virtual folder. If the destination is a file and already exists, it will be overwritten. @@ -532,7 +555,7 @@ def mv(name, src, dst) -> None: """ with Session() as session: try: - session.VFolder(name).move_file(src, dst) + _ = session.VFolder(name).move_file(src, dst) print_done("Moved.") except Exception as e: print_error(e) @@ -543,7 +566,7 @@ def mv(name, src, dst) -> None: @click.argument("name", type=str) @click.argument("filenames", nargs=-1, required=True) @click.option("-r", "--recursive", is_flag=True, help="Enable recursive deletion of directories.") -def rm(name, filenames, recursive) -> None: +def rm(name: str, filenames: tuple[str, ...], recursive: bool) -> None: """ Delete files in a virtual folder. If one of the given paths is a directory and the recursive option is enabled, @@ -560,7 +583,7 @@ def rm(name, filenames, recursive) -> None: if not ask_yn(): print_info("Cancelled") sys.exit(ExitCode.FAILURE) - session.VFolder(name).delete_files(filenames, recursive=recursive) + _ = session.VFolder(name).delete_files(filenames, recursive=recursive) print_done("Done.") except Exception as e: print_error(e) @@ -570,7 +593,7 @@ def rm(name, filenames, recursive) -> None: @vfolder.command() @click.argument("name", type=str) @click.argument("path", metavar="PATH", nargs=1, default=".") -def ls(name, path) -> None: +def ls(name: str, path: str) -> None: """ List files in a path of a virtual folder. @@ -610,7 +633,7 @@ def ls(name, path) -> None: default="rw", help='Permission to give. "ro" (read-only) / "rw" (read-write) / "wd" (write-delete).', ) -def invite(name, emails, perm) -> None: +def invite(name: str, emails: tuple[str, ...], perm: str) -> None: """Invite other users to access a user-type virtual folder. \b @@ -667,7 +690,7 @@ def invitations() -> None: while True: action = input("Choose action. 
(a)ccept, (r)eject, (c)ancel: ") if action.lower() == "a": - session.VFolder.accept_invitation(invitations[selection]["id"]) + _ = session.VFolder.accept_invitation(invitations[selection]["id"]) msg = "You can now access vfolder {} ({})".format( invitations[selection]["vfolder_name"], invitations[selection]["id"], @@ -675,7 +698,7 @@ def invitations() -> None: print(msg) break if action.lower() == "r": - session.VFolder.delete_invitation(invitations[selection]["id"]) + _ = session.VFolder.delete_invitation(invitations[selection]["id"]) msg = "vfolder invitation rejected: {} ({})".format( invitations[selection]["vfolder_name"], invitations[selection]["id"], @@ -700,7 +723,7 @@ def invitations() -> None: default="rw", help='Permission to give. "ro" (read-only) / "rw" (read-write) / "wd" (write-delete).', ) -def share(name, emails, perm) -> None: +def share(name: str, emails: tuple[str, ...], perm: str) -> None: """Share a group folder to users with overriding permission. \b @@ -727,7 +750,7 @@ def share(name, emails, perm) -> None: @vfolder.command() @click.argument("name", type=str) @click.argument("emails", type=str, nargs=-1, required=True) -def unshare(name, emails) -> None: +def unshare(name: str, emails: tuple[str, ...]) -> None: """Unshare a group folder from users. \b @@ -759,7 +782,7 @@ def unshare(name, emails) -> None: default=None, help="The ID of the person who wants to leave (the person who shared the vfolder).", ) -def leave(name, shared_user_uuid) -> None: +def leave(name: str, shared_user_uuid: str | None) -> None: """Leave the shared virtual folder. \b @@ -774,7 +797,7 @@ def leave(name, shared_user_uuid) -> None: if vfolder_info["is_owner"]: print("You cannot leave a virtual folder you own. Consider using delete instead.") return - session.VFolder(name).leave(shared_user_uuid) + _ = session.VFolder(name).leave(shared_user_uuid) print(f'Left the shared virtual folder "{name}".') except Exception as e: @@ -802,7 +825,7 @@ def leave(name, shared_user_uuid) -> None: default="rw", help="Cloned virtual folder's permission. Default value is 'rw'.", ) -def clone(name, target_name, target_host, usage_mode, permission) -> None: +def clone(name: str, target_name: str, target_host: str, usage_mode: str, permission: str) -> None: """Clone a virtual folder. \b @@ -895,7 +918,7 @@ async def clone_vfolder_tracker(bgtask_id: str) -> None: "If not set, the cloneable property is not changed." ), ) -def update_options(name, permission, set_cloneable) -> None: +def update_options(name: str, permission: str | None, set_cloneable: bool | None) -> None: """Update an existing virtual folder. \b @@ -907,7 +930,7 @@ def update_options(name, permission, set_cloneable) -> None: if not vfolder_info["is_owner"]: print("You cannot update virtual folder that you do not own.") return - session.VFolder(name).update_options( + _ = session.VFolder(name).update_options( name, permission=permission, cloneable=set_cloneable, @@ -975,7 +998,9 @@ def update_options(name, permission, set_cloneable) -> None: ) @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list_own(ctx: CLIContext, filter_, order, offset, limit) -> None: +def list_own( + ctx: CLIContext, filter_: str | None, order: str | None, offset: int, limit: int | None +) -> None: """ List own virtual folders. 
""" @@ -1055,7 +1080,9 @@ def list_own(ctx: CLIContext, filter_, order, offset, limit) -> None: ) @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list_invited(ctx: CLIContext, filter_, order, offset, limit) -> None: +def list_invited( + ctx: CLIContext, filter_: str | None, order: str | None, offset: int, limit: int | None +) -> None: """ List invited virtual folders. """ @@ -1135,7 +1162,9 @@ def list_invited(ctx: CLIContext, filter_, order, offset, limit) -> None: ) @click.option("--offset", default=0, help="The index of the current page start for pagination.") @click.option("--limit", type=int, default=None, help="The page size for pagination.") -def list_project(ctx: CLIContext, filter_, order, offset, limit) -> None: +def list_project( + ctx: CLIContext, filter_: str | None, order: str | None, offset: int, limit: int | None +) -> None: """ List project virtual folders. """ diff --git a/src/ai/backend/client/compat.py b/src/ai/backend/client/compat.py index 6563e5a936e..85d4f336940 100644 --- a/src/ai/backend/client/compat.py +++ b/src/ai/backend/client/compat.py @@ -27,7 +27,7 @@ all_tasks = asyncio.Task.all_tasks # type: ignore -def _cancel_all_tasks(loop) -> None: +def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: to_cancel = all_tasks(loop) if not to_cancel: return @@ -45,7 +45,7 @@ def _cancel_all_tasks(loop) -> None: }) -def _asyncio_run(coro, *, debug=False) -> Any: +def _asyncio_run(coro: Any, *, debug: bool = False) -> Any: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.set_debug(debug) @@ -68,7 +68,7 @@ def _asyncio_run(coro, *, debug=False) -> Any: asyncio_run = _asyncio_run # type: ignore[assignment] -def asyncio_run_forever(server_context, *, debug=False) -> Any: +def asyncio_run_forever(server_context: Any, *, debug: bool = False) -> Any: """ A proposed-but-not-implemented asyncio.run_forever() API based on @vxgmichel's idea. diff --git a/src/ai/backend/client/func/base.py b/src/ai/backend/client/func/base.py index 5380032a32e..d734bdf191c 100644 --- a/src/ai/backend/client/func/base.py +++ b/src/ai/backend/client/func/base.py @@ -18,7 +18,7 @@ def _wrap_method(cls: type, orig_name: str, meth: Callable) -> Callable: @functools.wraps(meth) - def _method(*args, **kwargs) -> Any: + def _method(*args: Any, **kwargs: Any) -> Any: # We need to keep the original attributes so that they could be correctly # bound to the class/instance at runtime. func = getattr(cls, orig_name) @@ -40,11 +40,11 @@ def _method(*args, **kwargs) -> Any: return _method -def api_function(meth) -> None: +def api_function[T: Callable](meth: T) -> T: """ Mark the wrapped method as the API function method. 
""" - meth._backend_api = True + meth._backend_api = True # type: ignore[attr-defined] return meth @@ -65,7 +65,7 @@ def field_resolver( default_fields: Iterable[FieldSpec], ) -> Callable: def decorator(meth: Callable) -> Callable: - def wrapper(*args, **kwargs) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: if fields := kwargs.get("fields", default_fields): resolved_fields = tuple( f.field_ref if isinstance(f, FieldSpec) else base_field_set[f].field_ref @@ -88,7 +88,9 @@ class APIFunctionMeta(type): _async = True - def __init__(cls, name, bases, attrs, **kwargs) -> None: + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], **kwargs: Any + ) -> None: super().__init__(name, bases, attrs) for attr_name, attr_value in attrs.items(): if hasattr(attr_value, "_backend_api"): diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index 8caa302ea84..23933ea69c0 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -685,7 +685,7 @@ async def restart(self) -> None: pass @api_function - async def rename(self, new_name) -> None: + async def rename(self, new_name: str) -> None: """ Renames Session ID of running compute session. """ @@ -1491,7 +1491,7 @@ async def restart(self) -> None: raise NotImplementedError @api_function - async def rename(self, new_id) -> None: + async def rename(self, new_id: str) -> None: """ Renames Session ID or running inference session. """ diff --git a/src/ai/backend/client/func/vfolder.py b/src/ai/backend/client/func/vfolder.py index d8e5cc78df4..ff1e00d3f2c 100644 --- a/src/ai/backend/client/func/vfolder.py +++ b/src/ai/backend/client/func/vfolder.py @@ -88,7 +88,7 @@ async def create( @api_function @classmethod - async def delete_by_id(cls, oid) -> dict[str, Any]: + async def delete_by_id(cls, oid: str) -> dict[str, Any]: rqst = Request("DELETE", "/folders") rqst.set_json({"id": oid}) async with rqst.fetch(): @@ -96,7 +96,7 @@ async def delete_by_id(cls, oid) -> dict[str, Any]: @api_function @classmethod - async def list(cls, list_all=False) -> dict[str, Any]: + async def list(cls, list_all: bool = False) -> dict[str, Any]: rqst = Request("GET", "/folders") rqst.set_json({"all": list_all}) async with rqst.fetch() as resp: @@ -309,7 +309,7 @@ async def force_delete(self) -> dict[str, Any]: return {} @api_function - async def rename(self, new_name) -> str: + async def rename(self, new_name: str) -> str: await self.update_id_by_name() rqst = Request("POST", f"/folders/{self.request_key}/rename") rqst.set_json({ @@ -671,13 +671,14 @@ async def delete_invitation(cls, inv_id: str) -> dict[str, Any]: @api_function @classmethod - async def get_fstab_contents(cls, agent_id=None) -> dict[str, Any]: + async def get_fstab_contents(cls, agent_id: Optional[str] = None) -> dict[str, Any]: + params: dict[str, str | int] = {} + if agent_id is not None: + params["agent_id"] = agent_id rqst = Request( "GET", "/folders/_/fstab", - params={ - "agent_id": agent_id, - }, + params=params, ) async with rqst.fetch() as resp: return await resp.json() @@ -699,7 +700,11 @@ async def list_mounts(cls) -> dict[str, Any]: @api_function @classmethod async def mount_host( - cls, name: str, fs_location: str, options=None, edit_fstab: bool = False + cls, + name: str, + fs_location: str, + options: Optional[dict[str, Any]] = None, + edit_fstab: bool = False, ) -> dict[str, Any]: rqst = Request("POST", "/folders/_/mounts") rqst.set_json({ @@ -744,7 +749,7 @@ async def unshare(self, emails: 
Sequence[str]) -> dict[str, Any]: return await resp.json() @api_function - async def leave(self, shared_user_uuid=None) -> dict[str, Any]: + async def leave(self, shared_user_uuid: Optional[str] = None) -> dict[str, Any]: await self.update_id_by_name() rqst = Request("POST", f"/folders/{self.request_key}/leave") rqst.set_json({"shared_user_uuid": shared_user_uuid}) diff --git a/src/ai/backend/client/output/console.py b/src/ai/backend/client/output/console.py index 20bf8036cb7..f854234cfc9 100644 --- a/src/ai/backend/client/output/console.py +++ b/src/ai/backend/client/output/console.py @@ -147,7 +147,7 @@ def print_paginated_list( fetch_func: Callable[[int, int], PaginatedResult], initial_page_offset: int, page_size: Optional[int] = None, - plain=False, + plain: bool = False, ) -> None: fields: list[FieldSpec] = [] diff --git a/src/ai/backend/client/output/formatters.py b/src/ai/backend/client/output/formatters.py index b5a30bee747..0201fba4c74 100644 --- a/src/ai/backend/client/output/formatters.py +++ b/src/ai/backend/client/output/formatters.py @@ -14,7 +14,7 @@ from .types import AbstractOutputFormatter, FieldSpec -def format_stats(raw_stats: Optional[str], indent="") -> str: +def format_stats(raw_stats: Optional[str], indent: str = "") -> str: if raw_stats is None: return "(unavailable)" stats = json.loads(raw_stats) @@ -318,7 +318,7 @@ def _fit_multiline_in_cell(text: str, indent: str) -> str: class ContainerListFormatter(NestedObjectFormatter): - def format_console(self, value: Any, field: FieldSpec, indent="") -> str: + def format_console(self, value: Any, field: FieldSpec, indent: str = "") -> str: if not isinstance(value, list): raise ValueError("ContainerListFormatter expects a list value") if len(value) == 0: @@ -337,7 +337,7 @@ def format_console(self, value: Any, field: FieldSpec, indent="") -> str: class DependencyListFormatter(NestedObjectFormatter): - def format_console(self, value: Any, field: FieldSpec, indent="") -> str: + def format_console(self, value: Any, field: FieldSpec, indent: str = "") -> str: if not isinstance(value, list): raise ValueError("DependencyListFormatter expects a list value") if len(value) == 0: diff --git a/src/ai/backend/client/output/json.py b/src/ai/backend/client/output/json.py index e774ada6706..7a864399647 100644 --- a/src/ai/backend/client/output/json.py +++ b/src/ai/backend/client/output/json.py @@ -123,7 +123,7 @@ def print_paginated_list( fetch_func: Callable[[int, int], PaginatedResult], initial_page_offset: int, page_size: Optional[int] = None, - plain=False, + plain: bool = False, ) -> None: page_size = page_size or 20 result = fetch_func(initial_page_offset, page_size) diff --git a/src/ai/backend/client/request.py b/src/ai/backend/client/request.py index ec3beff1e01..1f2521e52a2 100644 --- a/src/ai/backend/client/request.py +++ b/src/ai/backend/client/request.py @@ -434,7 +434,7 @@ class AsyncResponseMixin: async def text(self) -> str: return await self._raw_response.text() - async def json(self, *, loads=modjson.loads) -> Any: + async def json(self, *, loads: Callable[[str], Any] = modjson.loads) -> Any: loads = functools.partial(loads) return await self._raw_response.json(loads=loads) @@ -455,7 +455,7 @@ def text(self) -> str: self._raw_response.text(), ) - def json(self, *, loads=modjson.loads) -> Any: + def json(self, *, loads: Callable[[str], Any] = modjson.loads) -> Any: loads = functools.partial(loads) sync_session = cast(SyncSession, self._session) return sync_session.worker_thread.execute( diff --git 
a/src/ai/backend/client/session.py b/src/ai/backend/client/session.py index 3efe3510e36..9d7cdb9962f 100644 --- a/src/ai/backend/client/session.py +++ b/src/ai/backend/client/session.py @@ -99,7 +99,7 @@ async def _close_aiohttp_session(session: aiohttp.ClientSession) -> None: orig_lost = proto.connection_lost orig_eof_received = proto.eof_received - def connection_lost(exc) -> None: + def connection_lost(exc: Exception | None) -> None: orig_lost(exc) nonlocal transports transports -= 1 @@ -186,7 +186,7 @@ def execute(self, coro: Coroutine) -> Any: finally: del ctx - async def agen_wrapper(self, agen) -> None: + async def agen_wrapper(self, agen: AsyncIterator[Any]) -> None: self.agen_shutdown = False try: async for item in agen: @@ -569,7 +569,7 @@ async def __aexit__(self, *exc_info) -> Literal[False]: # TODO: Remove this after refactoring session management with contextvars @actxmgr -async def set_api_context(session) -> AsyncIterator[None]: +async def set_api_context(session: BaseSession) -> AsyncIterator[None]: token = api_session.set(session) try: yield diff --git a/src/ai/backend/client/utils.py b/src/ai/backend/client/utils.py index e2cbe961d97..97fb9b48696 100644 --- a/src/ai/backend/client/utils.py +++ b/src/ai/backend/client/utils.py @@ -3,6 +3,7 @@ import os import textwrap import uuid +from typing import Any from ai.backend.client.output.types import FieldSet, FieldSpec @@ -61,7 +62,7 @@ def flatten_connections_in_data(data: dict) -> dict: class ProgressReportingReader(io.BufferedReader): - def __init__(self, file_path, *, tqdm_instance=None) -> None: + def __init__(self, file_path: str, *, tqdm_instance: Any = None) -> None: super().__init__(open(file_path, "rb")) self._filename = os.path.basename(file_path) if tqdm_instance is None: @@ -80,7 +81,12 @@ def __init__(self, file_path, *, tqdm_instance=None) -> None: def __enter__(self) -> "ProgressReportingReader": return self - def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: Any, + ) -> None: if self._owns_tqdm: self.tqdm.close() self.close() diff --git a/src/ai/backend/client/versioning.py b/src/ai/backend/client/versioning.py index 512cced210b..83517f3398b 100644 --- a/src/ai/backend/client/versioning.py +++ b/src/ai/backend/client/versioning.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from .func.session import ComputeSession + from .session import BaseSession naming_profile = { @@ -36,7 +37,7 @@ def get_id_or_name(api_version: tuple[int, str], obj: ComputeSession) -> str: def apply_version_aware_fields( - api_session, + api_session: BaseSession, fields: Sequence[tuple[str, Callable | str]], ) -> Sequence[tuple[str, str]]: version_aware_fields = [] diff --git a/src/ai/backend/common/cli.py b/src/ai/backend/common/cli.py index 5529641b0c2..5148af518c0 100644 --- a/src/ai/backend/common/cli.py +++ b/src/ai/backend/common/cli.py @@ -10,9 +10,9 @@ import click -def wrap_method(method) -> Callable: +def wrap_method(method: Callable) -> Callable: @functools.wraps(method) - def wrapped(self, *args, **kwargs) -> Any: + def wrapped(self: Any, *args, **kwargs) -> Any: return method(self._impl, *args, **kwargs) return wrapped @@ -31,7 +31,7 @@ class LazyClickMixin: _import_name: str _loaded_impl: Optional[click.Command | click.Group] - def __init__(self, *, import_name, **kwargs) -> None: + def __init__(self, *, import_name: str, **kwargs) -> None: self._import_name = import_name self._loaded_impl = None 
super().__init__(**kwargs) @@ -63,14 +63,16 @@ def __init__(self, enum: type[Enum]) -> None: super().__init__(enum_members) self.enum = enum - def convert(self, value: Any, param, ctx) -> Enum: + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Enum: if isinstance(value, self.enum): # for default value, it is already the enum type. return next(e for e in self.enum if e == value) value = super().convert(value, param, ctx) return self.enum[value] - def get_metavar(self, param) -> str: + def get_metavar(self, param: click.Parameter) -> str: name = self.enum.__name__ name = re.sub(r"([A-Z\d]+)([A-Z][a-z])", r"\1_\2", name) name = re.sub(r"([a-z\d])([A-Z])", r"\1_\2", name) @@ -80,7 +82,9 @@ def get_metavar(self, param) -> str: class MinMaxRangeParamType(click.ParamType): name = "min-max decimal range" - def convert(self, value, param, ctx) -> tuple[Decimal | None, Decimal | None]: + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> tuple[Decimal | None, Decimal | None]: try: left, _, right = value.partition(":") if left: @@ -95,7 +99,7 @@ def convert(self, value, param, ctx) -> tuple[Decimal | None, Decimal | None]: except (ArithmeticError, ValueError): self.fail(f"{value!r} contains an invalid number", param, ctx) - def get_metavar(self, param) -> str: + def get_metavar(self, param: click.Parameter) -> str: return "MIN:MAX" diff --git a/src/ai/backend/common/configs/redis.py b/src/ai/backend/common/configs/redis.py index df7268218c9..f10ebde1954 100644 --- a/src/ai/backend/common/configs/redis.py +++ b/src/ai/backend/common/configs/redis.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import Annotated, Any, Optional from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator @@ -230,12 +230,12 @@ def _parse_sentinel( raise TypeError("sentinel must be list or 'host:port,host:port' string") @field_serializer("addr") - def _serialize_addr(self, addr: Optional[HostPortPairModel], _info) -> Optional[str]: + def _serialize_addr(self, addr: Optional[HostPortPairModel], _info: Any) -> Optional[str]: return None if addr is None else f"{addr.host}:{addr.port}" @field_serializer("sentinel") def _serialize_sentinel( - self, sentinel: Optional[list[HostPortPairModel]], _info + self, sentinel: Optional[list[HostPortPairModel]], _info: Any ) -> Optional[str]: if sentinel is None: return None diff --git a/src/ai/backend/common/dependencies/stacks/builder.py b/src/ai/backend/common/dependencies/stacks/builder.py index e44df086007..79f711b95ed 100644 --- a/src/ai/backend/common/dependencies/stacks/builder.py +++ b/src/ai/backend/common/dependencies/stacks/builder.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import AsyncExitStack -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from ai.backend.common.dependencies.base import DependencyStack @@ -102,7 +102,9 @@ async def __aenter__(self) -> DependencyBuilderStack: await self._stack.__aenter__() return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> bool | None: """ Exit the async context and cleanup resources in LIFO order. 
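The `__aexit__` signatures here, like the `__exit__` ones earlier, type the traceback argument as `Any`. If a follow-up wants to tighten that, the precise runtime type is `types.TracebackType | None`; a sketch under that assumption:

from types import TracebackType

class Managed:
    def __enter__(self) -> "Managed":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # all three are None on a clean exit, populated when an exception escapes
        pass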
""" diff --git a/src/ai/backend/common/dependencies/stacks/visualizing.py b/src/ai/backend/common/dependencies/stacks/visualizing.py index bcb6fddb001..c008a5ba804 100644 --- a/src/ai/backend/common/dependencies/stacks/visualizing.py +++ b/src/ai/backend/common/dependencies/stacks/visualizing.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from enum import Enum -from typing import TextIO +from typing import Any, TextIO from ai.backend.common.dependencies.base import ( DependencyComposer, @@ -189,7 +189,9 @@ async def __aenter__(self) -> VisualizingDependencyStack: await self._stack.__aenter__() return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> bool | None: """ Exit the async context and cleanup resources in LIFO order. """ diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index ca323e9e974..ae12cfbcf63 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -13,6 +13,7 @@ from pathlib import Path, PurePath from typing import ( TYPE_CHECKING, + Any, Final, Literal, NamedTuple, @@ -401,7 +402,7 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._data) - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, (set, frozenset)): return set(self._data.keys()) == other return self._data == other @@ -624,7 +625,9 @@ def generate_aliases(self) -> Mapping[str, ImageRef]: return ret @staticmethod - def merge_aliases(genned_aliases_1, genned_aliases_2) -> Mapping[str, ImageRef]: + def merge_aliases( + genned_aliases_1: Mapping[str, ImageRef], genned_aliases_2: Mapping[str, ImageRef] + ) -> Mapping[str, ImageRef]: ret = {} aliases_set_1, aliases_set_2 = set(genned_aliases_1.keys()), set(genned_aliases_2.keys()) aliases_dup = aliases_set_1 & aliases_set_2 @@ -668,7 +671,9 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((self.project, self.name, self.tag, self.registry, self.architecture)) - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: + if not isinstance(other, ImageRef): + return NotImplemented if self == other: # call __eq__ first for resolved check return False if not (self.name == other.name and self.project == other.project): diff --git a/src/ai/backend/common/enum_extension.py b/src/ai/backend/common/enum_extension.py index d1cdb45678f..4be4853d30d 100644 --- a/src/ai/backend/common/enum_extension.py +++ b/src/ai/backend/common/enum_extension.py @@ -1,18 +1,19 @@ from __future__ import annotations import enum +from typing import Any __all__ = ("StringSetFlag",) class StringSetFlag(enum.StrEnum): - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: return self.value == other def __hash__(self) -> int: return hash(self.value) - def __or__(self, other) -> set[str]: + def __or__(self, other: Any) -> set[str]: if isinstance(other, type(self)): other = other.value if not isinstance(other, (set, frozenset)): @@ -21,7 +22,7 @@ def __or__(self, other) -> set[str]: __ror__ = __or__ - def __and__(self, other) -> bool: + def __and__(self, other: Any) -> bool: if isinstance(other, (set, frozenset)): return self.value in other if isinstance(other, str): @@ -30,7 +31,7 @@ def __and__(self, other) -> bool: __rand__ = __and__ - def __xor__(self, other) -> set[str] | str: + def __xor__(self, other: Any) -> set[str] | str: if 
isinstance(other, (set, frozenset)): return {self.value} ^ other if isinstance(other, str): @@ -39,7 +40,7 @@ def __xor__(self, other) -> set[str] | str: return other raise TypeError - def __rxor__(self, other) -> set[str] | str: + def __rxor__(self, other: Any) -> set[str] | str: if isinstance(other, (set, frozenset)): return set(other) ^ {self.value} if isinstance(other, str): diff --git a/src/ai/backend/common/etcd.py b/src/ai/backend/common/etcd.py index 49a7277a570..298bcfd5dfd 100644 --- a/src/ai/backend/common/etcd.py +++ b/src/ai/backend/common/etcd.py @@ -82,11 +82,13 @@ class ConfigScopes(enum.Enum): quote = functools.partial(_quote, safe="") -def make_dict_from_pairs(key_prefix, pairs, path_sep="/") -> dict[str, Any]: +def make_dict_from_pairs( + key_prefix: str, pairs: Iterable[tuple[str, str]], path_sep: str = "/" +) -> dict[str, Any]: result: dict[str, Any] = {} len_prefix = len(key_prefix) if isinstance(pairs, dict): - iterator = pairs.items() + iterator: Iterable[tuple[str, str]] = pairs.items() else: iterator = pairs for k, v in iterator: @@ -574,7 +576,7 @@ async def get_prefix( scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] else: raise ValueError("Invalid scope prefix value") - pair_sets: list[list[Mapping | tuple]] = [] + pair_sets: list[list[tuple[str, str]]] = [] async with self.etcd.connect() as communicator: for scope_prefix in scope_prefixes: mangled_key_prefix = self._mangle_key(f"{_slash(scope_prefix)}{key_prefix}") diff --git a/src/ai/backend/common/files.py b/src/ai/backend/common/files.py index 9507c7b0252..432d84c4210 100644 --- a/src/ai/backend/common/files.py +++ b/src/ai/backend/common/files.py @@ -2,7 +2,7 @@ from collections.abc import Callable from pathlib import Path -from typing import Optional +from typing import Any, Optional import janus @@ -54,7 +54,9 @@ def _write(self) -> None: f.write(item) self._q.sync_q.task_done() - async def __aexit__(self, exc_type, exc, tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any + ) -> None: await self._q.async_q.put(Sentinel.TOKEN) try: await self._fut @@ -62,5 +64,5 @@ async def __aexit__(self, exc_type, exc, tb) -> None: self._q.close() await self._q.wait_closed() - async def write(self, item) -> None: + async def write(self, item: str | bytes) -> None: await self._q.async_q.put(item) diff --git a/src/ai/backend/common/json.py b/src/ai/backend/common/json.py index 84ee8481696..93a6ad700b3 100644 --- a/src/ai/backend/common/json.py +++ b/src/ai/backend/common/json.py @@ -7,7 +7,7 @@ class ExtendedJSONEncoder(json.JSONEncoder): - def default(self, o) -> Any: + def default(self, o: Any) -> Any: if isinstance(o, uuid.UUID): return str(o) if isinstance(o, datetime.datetime): diff --git a/src/ai/backend/common/redis_client.py b/src/ai/backend/common/redis_client.py index 1d1b9d69439..4797746e92d 100644 --- a/src/ai/backend/common/redis_client.py +++ b/src/ai/backend/common/redis_client.py @@ -58,7 +58,7 @@ def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - verbose=False, + verbose: bool = False, ) -> None: self.reader = reader self.writer = writer @@ -99,7 +99,7 @@ async def pipeline( commands: Sequence[Sequence[str | int | float | bytes | memoryview]], *, command_timeout: float | None = None, - return_exception=False, + return_exception: bool = False, ) -> Any: return await self._send( commands, @@ -112,7 +112,7 @@ async def _send( commands: Sequence[Sequence[str | int | float | bytes | memoryview]], *, 
command_timeout: float | None = None, - return_exception=False, + return_exception: bool = False, ) -> list[Any]: """ Executes a function that issues Redis commands or returns a pipeline/transaction of commands, diff --git a/src/ai/backend/common/testutils.py b/src/ai/backend/common/testutils.py index c28e66a0686..9192eca13a5 100644 --- a/src/ai/backend/common/testutils.py +++ b/src/ai/backend/common/testutils.py @@ -9,7 +9,7 @@ from asynctest import CoroutineMock as AsyncMock # type: ignore -def mock_corofunc(return_value) -> mock.Mock: +def mock_corofunc(return_value: Any) -> mock.Mock: """ Return mock coroutine function. @@ -50,5 +50,7 @@ def __init__(self, *args, **kwargs) -> None: async def __aenter__(self) -> "AsyncMock": return AsyncMock(**self.context) - async def __aexit__(self, exc_type, exc_value, exc_tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: Any + ) -> None: pass diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 7db50564498..02a527be7cd 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -830,7 +830,7 @@ def __str__(self) -> str: value = self._quantize(self, multiplier) return f"{value:f} {suffix.upper()}iB" - def __format__(self, format_spec) -> str: + def __format__(self, format_spec: str) -> str: if len(format_spec) != 1: raise ValueError("format-string for BinarySize can be only one character.") if format_spec == "s": @@ -956,7 +956,7 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: if not isinstance(other, ResourceSlot): - raise TypeError("Only can compare ResourceSlot objects.") + return NotImplemented self.sync_keys(other) return not self.__eq__(other) @@ -978,17 +978,17 @@ def eq_contained(self, other: ResourceSlot) -> bool: other_values = [other.data[k] for k in common_keys] return self_values == other_values and all(self[k] == 0 for k in only_self_keys) - def __le__(self, other: ResourceSlot) -> bool: + def __le__(self, other: object) -> bool: if not isinstance(other, ResourceSlot): - raise TypeError("Only can compare ResourceSlot objects.") + return NotImplemented self.sync_keys(other) self_values = [self.data[k] for k in self.keys()] other_values = [other.data[k] for k in self.keys()] return not any(s > o for s, o in zip(self_values, other_values, strict=True)) - def __lt__(self, other: ResourceSlot) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, ResourceSlot): - raise TypeError("Only can compare ResourceSlot objects.") + return NotImplemented self.sync_keys(other) self_values = [self.data[k] for k in self.keys()] other_values = [other.data[k] for k in self.keys()] @@ -996,17 +996,17 @@ def __lt__(self, other: ResourceSlot) -> bool: self_values == other_values ) - def __ge__(self, other: ResourceSlot) -> bool: + def __ge__(self, other: object) -> bool: if not isinstance(other, ResourceSlot): - raise TypeError("Only can compare ResourceSlot objects.") + return NotImplemented self.sync_keys(other) self_values = [self.data[k] for k in other.keys()] other_values = [other.data[k] for k in other.keys()] return not any(s < o for s, o in zip(self_values, other_values, strict=True)) - def __gt__(self, other: ResourceSlot) -> bool: + def __gt__(self, other: object) -> bool: if not isinstance(other, ResourceSlot): - raise TypeError("Only can compare ResourceSlot objects.") + return NotImplemented self.sync_keys(other) self_values = [self.data[k] for k in other.keys()] 
other_values = [other.data[k] for k in other.keys()] @@ -1228,7 +1228,7 @@ def __str__(self) -> str: return self.folder_id.hex return f"{self.quota_scope_id}/{self.folder_id.hex}" - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: return self.quota_scope_id == other.quota_scope_id and self.folder_id == other.folder_id def __hash__(self) -> int: @@ -1744,7 +1744,7 @@ def profile_target(self, role: RedisRole) -> RedisTarget: return self._base_target @staticmethod - def _parse_addr(addr_data) -> HostPortPair: + def _parse_addr(addr_data: Any) -> HostPortPair: match addr_data: case HostPortPair(host=host, port=port): return HostPortPair(host, port) diff --git a/src/ai/backend/common/utils.py b/src/ai/backend/common/utils.py index 2ead92a0311..51c839c5b3d 100644 --- a/src/ai/backend/common/utils.py +++ b/src/ai/backend/common/utils.py @@ -266,7 +266,7 @@ class FstabEntry: """ def __init__( - self, device: str, mountpoint: str, fstype: str, options: str | None, d=0, p=0 + self, device: str, mountpoint: str, fstype: str, options: str | None, d: int = 0, p: int = 0 ) -> None: self.device = device self.mountpoint = mountpoint @@ -277,7 +277,7 @@ def __init__( self.d = d self.p = p - def __eq__(self, o) -> bool: + def __eq__(self, o: Any) -> bool: return str(self) == str(o) def __str__(self) -> str: @@ -300,7 +300,16 @@ def __init__(self, fp: AsyncTextIOWrapper) -> None: self._fp = fp def _hydrate_entry(self, line: str) -> FstabEntry: - return FstabEntry(*[x for x in line.strip("\n").split(" ") if x not in ("", None)]) + parts = [x for x in line.strip("\n").split(" ") if x not in ("", None)] + # Ensure we have at least 4 parts (device, mountpoint, fstype, options) + # and convert the last two to integers if present + device = parts[0] if len(parts) > 0 else "" + mountpoint = parts[1] if len(parts) > 1 else "" + fstype = parts[2] if len(parts) > 2 else "" + options = parts[3] if len(parts) > 3 else None + d = int(parts[4]) if len(parts) > 4 else 0 + p = int(parts[5]) if len(parts) > 5 else 0 + return FstabEntry(device, mountpoint, fstype, options, d, p) async def get_entries(self) -> AsyncIterator[FstabEntry]: await self._fp.seek(0) @@ -326,7 +335,13 @@ async def add_entry(self, entry: FstabEntry) -> None: await self._fp.truncate() async def add( - self, device: str, mountpoint: str, fstype: str, options: str | None = None, d=0, p=0 + self, + device: str, + mountpoint: str, + fstype: str, + options: str | None = None, + d: int = 0, + p: int = 0, ) -> None: return await self.add_entry(FstabEntry(device, mountpoint, fstype, options, d, p)) diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py index 15722d9c2e4..1bdeef8f951 100644 --- a/src/ai/backend/common/validators.py +++ b/src/ai/backend/common/validators.py @@ -73,7 +73,7 @@ class StringLengthMeta(TrafaretMeta): A metaclass that makes string-like trafarets to have sliced min/max length indicator. 
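The `ResourceSlot` comparison methods earlier in this change switch from raising `TypeError` to accepting `other: object` and returning `NotImplemented`, which is the protocol-correct behavior: the interpreter then tries the reflected operation on the other operand and raises `TypeError` itself only if both sides decline. A minimal sketch of the pattern:

class Cents:
    def __init__(self, value: int) -> None:
        self.value = value

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Cents):
            # let Python try other.__gt__(self) before giving up
            return NotImplemented
        return self.value < other.value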
""" - def __getitem__(cls, slice_) -> t.Trafaret: + def __getitem__(cls, slice_: slice) -> t.Trafaret: return cls(min_length=slice_.start, max_length=slice_.stop) @@ -90,7 +90,7 @@ def __init__(self, names: Sequence[str], **kwargs) -> None: super().__init__(names[0], **kwargs) self.names = names - def __call__(self, data, context=None) -> Generator[tuple, None, None]: # type: ignore[override] + def __call__(self, data: Any, context: Any = None) -> Generator[tuple, None, None]: # type: ignore[override] for name in self.names: if name in data: key = name @@ -121,7 +121,7 @@ def __call__(self, data, context=None) -> Generator[tuple, None, None]: # type: class MultiKey(t.Key): - def get_data(self, data, default) -> list[Any]: + def get_data(self, data: Any, default: Any) -> list[Any]: if isinstance(data, (multidict.MultiDict, multidict.MultiDictProxy)): return data.getall(self.name, default) # type: ignore[attr-defined] # fallback for plain dicts diff --git a/src/ai/backend/helpers/package.py b/src/ai/backend/helpers/package.py index 0eaf1f237ba..cb2a5260725 100644 --- a/src/ai/backend/helpers/package.py +++ b/src/ai/backend/helpers/package.py @@ -11,7 +11,7 @@ __all__ = ("install",) -def install(pkgname, force_install=False) -> None: +def install(pkgname: str, force_install: bool = False) -> None: """ Install a Python package from pypi.org or the given index server. The package is installed inside the user site directory. diff --git a/src/ai/backend/install/context.py b/src/ai/backend/install/context.py index 84b0033595d..ab9715ac2a2 100644 --- a/src/ai/backend/install/context.py +++ b/src/ai/backend/install/context.py @@ -125,7 +125,7 @@ def log_header(self, title: str) -> None: def mangle_pkgname(self, name: str, fat: bool = False) -> str: return f"backendai-{name}-{self.os_info.platform}" - def generate_passphrase(self, len=16) -> str: + def generate_passphrase(self, len: int = 16) -> str: return "".join(random.sample(PASSPHRASE_CHARACTER_POOL, len)) @staticmethod diff --git a/src/ai/backend/install/docker.py b/src/ai/backend/install/docker.py index 4734e9bb576..8222f898e52 100644 --- a/src/ai/backend/install/docker.py +++ b/src/ai/backend/install/docker.py @@ -28,8 +28,8 @@ ) -def parse_version(expr) -> tuple: - result = [] +def parse_version(expr: str) -> tuple[int | str, ...]: + result: list[int | str] = [] for part in expr.split("."): try: result.append(int(part)) diff --git a/src/ai/backend/install/widgets.py b/src/ai/backend/install/widgets.py index 82a3b862af4..0c23d1334b6 100644 --- a/src/ai/backend/install/widgets.py +++ b/src/ai/backend/install/widgets.py @@ -10,6 +10,7 @@ from textual.app import ComposeResult from textual.binding import Binding from textual.containers import Horizontal +from textual.events import Mount from textual.validation import ValidationResult, Validator from textual.widget import Widget from textual.widgets import ( @@ -176,7 +177,7 @@ def compose(self) -> ComposeResult: if self._allow_cancel: yield Button("Cancel", id="button-cancel") - def on_mount(self, _) -> None: + def on_mount(self, _: Mount) -> None: self._override_focus() self.query_one(Input).focus() diff --git a/src/ai/backend/kernel/app/__init__.py b/src/ai/backend/kernel/app/__init__.py index f140cc92573..7db3278c117 100644 --- a/src/ai/backend/kernel/app/__init__.py +++ b/src/ai/backend/kernel/app/__init__.py @@ -4,6 +4,8 @@ """ import logging +from collections.abc import Mapping +from typing import Any from ai.backend.kernel import BaseRunner @@ -30,6 +32,6 @@ async def 
execute_heuristic(self) -> int: log.warning("batch-mode execution is not supported") return 0 - async def start_service(self, service_info) -> tuple[None, dict]: + async def start_service(self, service_info: Mapping[str, Any]) -> tuple[None, dict]: # app kernels use service-definition templates. return None, {} diff --git a/src/ai/backend/kernel/base.py b/src/ai/backend/kernel/base.py index 2a8f314e3ae..68363cce648 100644 --- a/src/ai/backend/kernel/base.py +++ b/src/ai/backend/kernel/base.py @@ -67,11 +67,16 @@ class HealthStatus(enum.Enum): UNDETERMINED = 2 -async def pipe_output(stream, outsock, target, log_fd) -> None: +async def pipe_output( + stream: asyncio.StreamReader, + outsock: zmq.Socket, + target: str, + log_fd: int, +) -> None: if target not in ("stdout", "stderr"): raise ValueError(f"Invalid target: {target}. Must be 'stdout' or 'stderr'") - target = target.encode("ascii") console_fd = sys.stdout.fileno() if target == "stdout" else sys.stderr.fileno() + target_bytes = target.encode("ascii") loop = current_loop() try: while True: @@ -81,7 +86,7 @@ async def pipe_output(stream, outsock, target, log_fd) -> None: await asyncio.gather( loop.run_in_executor(None, os.write, console_fd, data), loop.run_in_executor(None, os.write, log_fd, data), - outsock.send_multipart([target, data]), + outsock.send_multipart([target_bytes, data]), return_exceptions=True, ) except asyncio.CancelledError: @@ -508,7 +513,7 @@ async def _query(self, code_text: str) -> None: }).encode("utf8") await self.outsock.send_multipart([b"finished", payload]) - async def query(self, code_text) -> int: + async def query(self, code_text: str) -> int: """Run user's code in query mode. The default interface is jupyter kernel. To use different interface, @@ -523,7 +528,7 @@ async def query(self, code_text) -> int: log.debug("executing in query mode...") exit_code = 0 - async def output_hook(msg) -> None: + async def output_hook(msg: Mapping[str, Any]) -> None: nonlocal exit_code content = msg.get("content", "") if msg["msg_type"] == "stream": @@ -577,7 +582,7 @@ async def output_hook(msg) -> None: json.dumps({"type": dtype, "data": dval}).encode("utf8"), ]) - async def stdin_hook(msg) -> None: + async def stdin_hook(msg: Mapping[str, Any]) -> None: if self.kernel_client is None: raise RuntimeError("Kernel client is not initialized") if self.user_input_queue is None: @@ -623,7 +628,7 @@ async def _complete(self, completion_data: Any) -> None: json.dumps({"suggestions": result}).encode("utf8"), ]) - async def complete(self, completion_data) -> Sequence[str]: + async def complete(self, completion_data: Any) -> Sequence[str]: """Return the list of strings to be shown in the auto-complete list. The default interface is jupyter kernel. 
To use different interface, @@ -728,7 +733,8 @@ async def _send_status(self) -> None: @abstractmethod async def start_service( - self, service_info + self, + service_info: Mapping[str, Any], ) -> ( tuple[list[str] | None, dict[str, str]] | tuple[list[str] | None, dict[str, str], str] @@ -750,7 +756,7 @@ async def start_model_service(self, model_info: Mapping[str, Any]) -> None: # After the None check, narrow the type model_service_info = cast(Mapping[str, Any], model_service_info) service_name = f"{model_info['name']}-{model_service_info['port']}" - self.service_parser.add_model_service(service_name, model_service_info) + self.service_parser.add_model_service(service_name, dict(model_service_info)) service_info = { "name": service_name, "port": model_service_info["port"], @@ -1028,6 +1034,8 @@ async def run_subproc(self, cmd: str | list[str], batch: bool = False) -> int: **pipe_opts, ) self.subproc = proc + if proc.stdout is None or proc.stderr is None: + raise RuntimeError("Process stdout or stderr is None") pipe_tasks = [ loop.create_task( pipe_output(proc.stdout, self.outsock, "stdout", log_out.fileno()) diff --git a/src/ai/backend/kernel/exception.py b/src/ai/backend/kernel/exception.py index 8db72c06765..69983c14442 100644 --- a/src/ai/backend/kernel/exception.py +++ b/src/ai/backend/kernel/exception.py @@ -1,5 +1,5 @@ class MessageError(ValueError): - def __init__(self, message) -> None: + def __init__(self, message: str) -> None: super().__init__(message) self.message = message diff --git a/src/ai/backend/kernel/intrinsic.py b/src/ai/backend/kernel/intrinsic.py index d80eca2267f..f5253a76e89 100644 --- a/src/ai/backend/kernel/intrinsic.py +++ b/src/ai/backend/kernel/intrinsic.py @@ -3,15 +3,16 @@ import logging import os import shutil -from collections.abc import Iterable +from collections.abc import Iterable, Mapping, MutableMapping from pathlib import Path +from typing import Any from .logging import BraceStyleAdapter log = BraceStyleAdapter(logging.getLogger()) -async def init_sshd_service(child_env) -> None: +async def init_sshd_service(child_env: MutableMapping[str, str]) -> None: if Path("/tmp/dropbear").is_dir(): shutil.rmtree("/tmp/dropbear") Path("/tmp/dropbear").mkdir(parents=True, exist_ok=True) @@ -127,7 +128,7 @@ async def init_sshd_service(child_env) -> None: f.write(b"\n") -async def prepare_sshd_service(service_info) -> tuple[list[str], dict]: +async def prepare_sshd_service(service_info: Mapping[str, Any]) -> tuple[list[str], dict[str, str]]: cmdargs = [ "/opt/kernel/dropbearmulti", "dropbear", @@ -153,7 +154,7 @@ async def prepare_sshd_service(service_info) -> tuple[list[str], dict]: return cmdargs, env -async def prepare_ttyd_service(service_info) -> tuple[list[str], dict]: +async def prepare_ttyd_service(service_info: Mapping[str, Any]) -> tuple[list[str], dict[str, str]]: shell = "sh" if Path("/bin/zsh").exists(): shell = "zsh" diff --git a/src/ai/backend/kernel/jupyter_client.py b/src/ai/backend/kernel/jupyter_client.py index 412fead5eff..0bff6cb0ce8 100644 --- a/src/ai/backend/kernel/jupyter_client.py +++ b/src/ai/backend/kernel/jupyter_client.py @@ -1,5 +1,7 @@ import inspect +from collections.abc import Awaitable, Callable, Mapping from time import monotonic +from typing import Any, Optional import zmq import zmq.asyncio @@ -9,14 +11,18 @@ async def aexecute_interactive( kernel_client: AsyncKernelClient, code: str, - silent=False, - store_history=True, - user_expressions=None, - allow_stdin=None, - stop_on_error=True, - timeout=None, - output_hook=None, - 
stdin_hook=None, + silent: bool = False, + store_history: bool = True, + user_expressions: Optional[Mapping[str, Any]] = None, + allow_stdin: Optional[bool] = None, + stop_on_error: bool = True, + timeout: Optional[float] = None, + output_hook: Callable[[Mapping[str, Any]], Any] + | Callable[[Mapping[str, Any]], Awaitable[Any]] + | None = None, + stdin_hook: Callable[[Mapping[str, Any]], Any] + | Callable[[Mapping[str, Any]], Awaitable[Any]] + | None = None, ) -> dict: """Async version of jupyter_client's execute_interactive method. @@ -36,8 +42,8 @@ async def aexecute_interactive( allow_stdin=allow_stdin, stop_on_error=stop_on_error, ) - stdin_hook = stdin_hook if stdin_hook else kernel_client._stdin_hook_default - output_hook = output_hook if output_hook else kernel_client._output_hook_default + stdin_hook = stdin_hook if stdin_hook else kernel_client._stdin_hook_default # type: ignore[assignment] + output_hook = output_hook if output_hook else kernel_client._output_hook_default # type: ignore[assignment] # set deadline based on timeout if timeout is not None: @@ -64,6 +70,8 @@ async def aexecute_interactive( raise TimeoutError("Timeout waiting for output") if stdin_socket in events: req = await kernel_client.stdin_channel.get_msg(timeout=0) + if stdin_hook is None: + raise RuntimeError("stdin_hook is None") if inspect.iscoroutinefunction(stdin_hook): await stdin_hook(req) else: @@ -77,6 +85,8 @@ async def aexecute_interactive( if msg["parent_header"].get("msg_id") != msg_id: # not from my request continue + if output_hook is None: + raise RuntimeError("output_hook is None") if inspect.iscoroutinefunction(output_hook): await output_hook(msg) else: diff --git a/src/ai/backend/kernel/logging.py b/src/ai/backend/kernel/logging.py index 35f5db385fa..9a8bf8a220a 100644 --- a/src/ai/backend/kernel/logging.py +++ b/src/ai/backend/kernel/logging.py @@ -11,7 +11,7 @@ class RelativeCreatedFormatter(logging.Formatter): - def format(self, record) -> str: + def format(self, record: logging.LogRecord) -> str: record.relative_seconds = record.relativeCreated / 1000 return super().format(record) @@ -19,7 +19,7 @@ def format(self, record) -> str: class BraceMessage: __slots__ = ("args", "fmt") - def __init__(self, fmt, args) -> None: + def __init__(self, fmt: str, args: tuple) -> None: self.fmt = fmt self.args = args @@ -28,7 +28,7 @@ def __str__(self) -> str: class BraceStyleAdapter(logging.LoggerAdapter): - def log(self, level, msg, *args, **kwargs) -> None: + def log(self, level: int, msg: object, *args: object, **kwargs: object) -> None: # type: ignore[override] if self.isEnabledFor(level): _msg, _kwargs = self.process(msg, kwargs) self.logger._log(level, BraceMessage(_msg, args), (), **_kwargs) diff --git a/src/ai/backend/kernel/python/__init__.py b/src/ai/backend/kernel/python/__init__.py index 7a408e81ac6..30b536b6301 100644 --- a/src/ai/backend/kernel/python/__init__.py +++ b/src/ai/backend/kernel/python/__init__.py @@ -3,8 +3,9 @@ import os import shutil import tempfile +from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import janus @@ -88,7 +89,7 @@ async def execute_heuristic(self) -> int: return 127 async def start_service( - self, service_info + self, service_info: Mapping[str, Any] ) -> ( tuple[list[str] | None, dict[str, str]] | tuple[list[str] | None, dict[str, str], str] @@ -129,6 +130,7 @@ async def start_service( ], {} if service_info["name"] == "tensorboard": Path("/home/work/logs").mkdir(parents=True, 
exist_ok=True) + port_str = str(service_info["port"]) return [ str(self.runtime_path), "-m", @@ -138,7 +140,7 @@ async def start_service( "--host", "0.0.0.0", "--port", - str(service_info["port"]), + port_str, "--debugger_port", "6064", # used by in-container TensorFlow ], {} @@ -154,11 +156,12 @@ async def start_service( "/home/work/spectravis", ) if service_info["name"] == "sftp": + port_str = str(service_info["port"]) return [ str(self.runtime_path), "-m", "sftpserver", "--port", - str(service_info["port"]), + port_str, ], {} return None diff --git a/src/ai/backend/kernel/python/drawing/canvas.py b/src/ai/backend/kernel/python/drawing/canvas.py index f8e69fe989b..d45257c923e 100644 --- a/src/ai/backend/kernel/python/drawing/canvas.py +++ b/src/ai/backend/kernel/python/drawing/canvas.py @@ -6,7 +6,7 @@ from ai.backend.kernel.python.types import MediaRecord -from .color import Colors +from .color import Color, Colors from .encoding import encode_commands from .turtle import Turtle @@ -23,66 +23,84 @@ class DrawingObject: - def __init__(self, canvas, id_, args) -> None: + def __init__(self, canvas: Canvas, id_: int, args: tuple[Any, ...]) -> None: self._canvas = canvas self._id = id_ self._type = args[0] - def set_x(self, x) -> None: + def set_x(self, x: float | int) -> None: if self._type in ("rect", "circle", "triangle"): self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "x", x)) - def set_y(self, y) -> None: + def set_y(self, y: float | int) -> None: if self._type in ("rect", "circle", "triangle"): self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "y", y)) - def set_x1(self, x) -> None: + def set_x1(self, x: float | int) -> None: if self._type == "line": self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "x1", x)) - def set_y1(self, y) -> None: + def set_y1(self, y: float | int) -> None: if self._type == "line": self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "y1", y)) - def set_x2(self, x) -> None: + def set_x2(self, x: float | int) -> None: if self._type == "line": self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "x2", x)) - def set_y2(self, y) -> None: + def set_y2(self, y: float | int) -> None: if self._type == "line": self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "y2", y)) - def set_radius(self, r) -> None: + def set_radius(self, r: float | int) -> None: if self._type == "circle": self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "radius", r)) - def rotate(self, a) -> None: + def rotate(self, a: float | int) -> None: self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "rotate", a)) - def set_angle(self, a) -> None: + def set_angle(self, a: float | int) -> None: self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "angle", a)) - def stroke(self, color) -> None: - color = color.to_hex() + def stroke(self, color: Color) -> None: + color_hex = color.to_hex() if self._type == "line": - self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "color", color)) + self._canvas._cmd_history.append(( + self._canvas._id, + "update", + self._id, + "color", + color_hex, + )) elif self._type == "circle" or self._type in ("rect", "triangle"): self._canvas._cmd_history.append(( self._canvas._id, "update", self._id, "border", - color, + color_hex, )) - def fill(self, color) -> None: - color = color.to_hex() + def fill(self, color: Color) -> None: + color_hex = color.to_hex() if self._type 
== "circle" or self._type in ("rect", "triangle"): - self._canvas._cmd_history.append((self._canvas._id, "update", self._id, "fill", color)) + self._canvas._cmd_history.append(( + self._canvas._id, + "update", + self._id, + "fill", + color_hex, + )) class Canvas: - def __init__(self, width, height, bgcolor=Colors.White, fgcolor=Colors.Black) -> None: + def __init__( + self, + width: float | int, + height: float | int, + bgcolor: Color = Colors.White, + fgcolor: Color = Colors.Black, + ) -> None: global _canvas_id_counter self._id = _canvas_id_counter _canvas_id_counter += 1 @@ -128,21 +146,28 @@ def begin_group(self) -> None: def end_group(self) -> None: self._cmd_history.append((self._id, "end-group")) - def begin_fill(self, c) -> None: + def begin_fill(self, c: Color) -> None: self._cmd_history.append((self._id, "begin-fill", c.to_hex())) def end_fill(self) -> None: self._cmd_history.append((self._id, "end-fill")) - def background_color(self, c) -> None: + def background_color(self, c: Color) -> None: self.bgcolor = c self._cmd_history.append((self._id, "bgcolor", c.to_hex())) - def stroke_color(self, c) -> None: + def stroke_color(self, c: Color) -> None: self.fgcolor = c self._cmd_history.append((self._id, "fgcolor", c.to_hex())) - def line(self, x0, y0, x1, y1, color=None) -> DrawingObject: + def line( + self, + x0: float | int, + y0: float | int, + x1: float | int, + y1: float | int, + color: Color | None = None, + ) -> DrawingObject: if color is None: color = self.fgcolor args = ( @@ -158,7 +183,15 @@ def line(self, x0, y0, x1, y1, color=None) -> DrawingObject: self._next_objid += 1 return obj - def circle(self, x, y, radius, border=None, fill=None, angle=0) -> DrawingObject: + def circle( + self, + x: float | int, + y: float | int, + radius: float | int, + border: Color | None = None, + fill: Color | None = None, + angle: float | int = 0, + ) -> DrawingObject: if border is None: border = self.fgcolor if fill is None: @@ -177,7 +210,16 @@ def circle(self, x, y, radius, border=None, fill=None, angle=0) -> DrawingObject self._next_objid += 1 return obj - def rectangle(self, left, top, width, height, border=None, fill=None, angle=0) -> DrawingObject: + def rectangle( + self, + left: float | int, + top: float | int, + width: float | int, + height: float | int, + border: Color | None = None, + fill: Color | None = None, + angle: float | int = 0, + ) -> DrawingObject: if border is None: border = self.fgcolor if fill is None: @@ -197,7 +239,16 @@ def rectangle(self, left, top, width, height, border=None, fill=None, angle=0) - self._next_objid += 1 return obj - def triangle(self, left, top, width, height, border=None, fill=None, angle=0) -> DrawingObject: + def triangle( + self, + left: float | int, + top: float | int, + width: float | int, + height: float | int, + border: Color | None = None, + fill: Color | None = None, + angle: float | int = 0, + ) -> DrawingObject: if border is None: border = self.fgcolor if fill is None: diff --git a/src/ai/backend/kernel/python/drawing/color.py b/src/ai/backend/kernel/python/drawing/color.py index 99ff4a42bba..20666a8919e 100644 --- a/src/ai/backend/kernel/python/drawing/color.py +++ b/src/ai/backend/kernel/python/drawing/color.py @@ -8,7 +8,7 @@ class Color: - def __init__(self, red, green, blue, alpha=255) -> None: + def __init__(self, red: int, green: int, blue: int, alpha: int = 255) -> None: self.red = red self.green = green self.blue = blue @@ -32,7 +32,7 @@ def from_bytes(value: bytes) -> Color: r, g, b, a = rgba.unpack(value) return Color(r, 
g, b, a) - def to_hex(self, include_alpha=True) -> str: + def to_hex(self, include_alpha: bool = True) -> str: if include_alpha: return f"#{self.red:02x}{self.green:02x}{self.blue:02x}{self.alpha:02x}" return f"#{self.red:02x}{self.green:02x}{self.blue:02x}" diff --git a/src/ai/backend/kernel/python/drawing/encoding.py b/src/ai/backend/kernel/python/drawing/encoding.py index 26ad9f85689..890addee95b 100644 --- a/src/ai/backend/kernel/python/drawing/encoding.py +++ b/src/ai/backend/kernel/python/drawing/encoding.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import base64 +from typing import TYPE_CHECKING, Any import msgpack +if TYPE_CHECKING: + from collections.abc import Sequence + -def encode_commands(cmdlist) -> str: +def encode_commands(cmdlist: Sequence[Any]) -> str: bindata = msgpack.packb(cmdlist, use_bin_type=True) return base64.b64encode(bindata).decode("ascii") -def decode_commands(data) -> list: +def decode_commands(data: str | bytes) -> list[Any]: bindata = base64.b64decode(data) return msgpack.unpackb(bindata, raw=False) diff --git a/src/ai/backend/kernel/python/drawing/turtle.py b/src/ai/backend/kernel/python/drawing/turtle.py index 5a4409d81ea..f3c2e79bbf8 100644 --- a/src/ai/backend/kernel/python/drawing/turtle.py +++ b/src/ai/backend/kernel/python/drawing/turtle.py @@ -1,30 +1,34 @@ from __future__ import annotations import math +from typing import TYPE_CHECKING from .color import Colors +if TYPE_CHECKING: + from .canvas import Canvas + class Vec2D(tuple): """A helper class taken from Python stdlib's Turtle package.""" - def __new__(cls, x, y): + def __new__(cls, x: float | int, y: float | int): return tuple.__new__(cls, (x, y)) - def __add__(self, other): + def __add__(self, other: Vec2D): # type: ignore[override] return Vec2D(self[0] + other[0], self[1] + other[1]) - def __mul__(self, other): + def __mul__(self, other: Vec2D | float | int): # type: ignore[override] if isinstance(other, Vec2D): return self[0] * other[0] + self[1] * other[1] return Vec2D(self[0] * other, self[1] * other) - def __rmul__(self, other): + def __rmul__(self, other: float | int): # type: ignore[override] if isinstance(other, (int, float)): return Vec2D(self[0] * other, self[1] * other) return None - def __sub__(self, other): + def __sub__(self, other: Vec2D): # type: ignore[override] return Vec2D(self[0] - other[0], self[1] - other[1]) def __neg__(self): @@ -33,7 +37,7 @@ def __neg__(self): def __abs__(self): return (self[0] ** 2 + self[1] ** 2) ** 0.5 - def rotate(self, angle) -> Vec2D: + def rotate(self, angle: float | int) -> Vec2D: """rotate self counterclockwise by angle""" perp = Vec2D(-self[1], self[0]) angle = angle * math.pi / 180.0 @@ -48,7 +52,7 @@ def __repr__(self) -> str: class Turtle: - def __init__(self, canvas) -> None: + def __init__(self, canvas: Canvas) -> None: self.canvas = canvas self.points = [] self.pen = True @@ -66,7 +70,7 @@ def __init__(self, canvas) -> None: self.angle = 90 self.points.append((w / 2, h / 2)) - def forward(self, amt) -> None: + def forward(self, amt: float | int) -> None: x = self.points[-1][0] y = self.points[-1][1] x_diff = math.sin(math.radians(self.angle)) * amt @@ -79,13 +83,13 @@ def forward(self, amt) -> None: self.canvas.end_group() self.points.append((x + x_diff, y + y_diff)) - def left(self, deg) -> None: + def left(self, deg: float | int) -> None: self.cursor.rotate(-deg) - self.angle -= deg + self.angle -= deg # type: ignore[assignment] - def right(self, deg) -> None: + def right(self, deg: float | int) -> None: 
self.cursor.rotate(deg) - self.angle += deg + self.angle += deg # type: ignore[assignment] def pos(self) -> Vec2D: base_x, base_y = self.points[0][0], self.points[0][1] @@ -97,25 +101,25 @@ def penup(self) -> None: def pendown(self) -> None: self.pen = True - def setpos(self, x, y=None) -> None: + def setpos(self, x: float | int | Vec2D, y: float | int | None = None) -> None: base_x, base_y = self.points[0][0], self.points[0][1] if y is None: - _x = x[0] - _y = x[1] + _x = x[0] # type: ignore[index] + _y = x[1] # type: ignore[index] x, y = _x, _y self.canvas.begin_group() if self.pen: self.canvas.line( self.points[-1][0], self.points[-1][1], - x + base_x, - y + base_y, + x + base_x, # type: ignore[operator] + y + base_y, # type: ignore[operator] color=Colors.from_rgba([255, 0, 0, 128]), ) - self.cursor.set_x(x + base_x) - self.cursor.set_y(y + base_y) + self.cursor.set_x(x + base_x) # type: ignore[operator] + self.cursor.set_y(y + base_y) # type: ignore[operator] self.canvas.end_group() - self.points.append((x + base_x, y + base_y)) + self.points.append((x + base_x, y + base_y)) # type: ignore[operator] __all__ = [ diff --git a/src/ai/backend/kernel/python/sitecustomize.py b/src/ai/backend/kernel/python/sitecustomize.py index 0c5b9e33688..ab1486fef9c 100644 --- a/src/ai/backend/kernel/python/sitecustomize.py +++ b/src/ai/backend/kernel/python/sitecustomize.py @@ -11,8 +11,8 @@ if sys.version_info.major > 2: import builtins - def _input(prompt="") -> str: - sys.stdout.write(prompt) + def _input(prompt: object = "") -> str: + sys.stdout.write(str(prompt)) sys.stdout.flush() with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: try: @@ -31,8 +31,8 @@ def _input(prompt="") -> str: builtins = __builtin__ - def _raw_input(prompt="") -> str: - sys.stdout.write(prompt) + def _raw_input(prompt: object = "") -> str: + sys.stdout.write(str(prompt)) sys.stdout.flush() try: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/src/ai/backend/kernel/service.py b/src/ai/backend/kernel/service.py index b522ae268c0..9a7b78d93a9 100644 --- a/src/ai/backend/kernel/service.py +++ b/src/ai/backend/kernel/service.py @@ -84,7 +84,7 @@ async def parse(self, path: Path) -> None: except TypeError as e: raise InvalidServiceDefinition(e.args[0][11:]) from e # lstrip "__init__() " - def add_model_service(self, name, model_service_info) -> None: + def add_model_service(self, name: str, model_service_info: dict[str, Any]) -> None: service_def = ServiceDefinition( model_service_info["start_command"], shell=model_service_info["shell"], diff --git a/src/ai/backend/kernel/service_actions.py b/src/ai/backend/kernel/service_actions.py index 98cabe8a5f6..7a70950d298 100644 --- a/src/ai/backend/kernel/service_actions.py +++ b/src/ai/backend/kernel/service_actions.py @@ -41,7 +41,7 @@ async def write_tempfile( async def run_command( variables: Mapping[str, Any], command: Iterable[str], - echo=False, + echo: bool = False, ) -> Optional[MutableMapping[str, str]]: proc = await create_subprocess_exec( *(str(piece).format_map(variables) for piece in command), diff --git a/src/ai/backend/kernel/terminal.py b/src/ai/backend/kernel/terminal.py index 8bd5c267b35..d5abf3fbcf1 100644 --- a/src/ai/backend/kernel/terminal.py +++ b/src/ai/backend/kernel/terminal.py @@ -21,7 +21,7 @@ from .utils import safe_close_task if TYPE_CHECKING: - from asyncio import Task + from asyncio import AbstractEventLoop, Task log = BraceStyleAdapter(logging.getLogger()) @@ -31,7 +31,15 @@ class Terminal: A wrapper for a terminal-based 
app. """ - def __init__(self, shell_cmd, ev_term, sock_out, *, auto_restart=True, loop=None) -> None: + def __init__( + self, + shell_cmd: list[str], + ev_term: asyncio.Event, + sock_out: zmq.asyncio.Socket, + *, + auto_restart: bool = True, + loop: AbstractEventLoop | None = None, + ) -> None: self._sorna_media: list = [] self.zctx = sock_out.context @@ -65,11 +73,11 @@ def __init__(self, shell_cmd, ev_term, sock_out, *, auto_restart=True, loop=None parser_resize.add_argument("cols", type=int) parser_resize.set_defaults(func=self.do_resize_term) - async def do_ping(self, args) -> int: + async def do_ping(self, args: argparse.Namespace) -> int: await self.sock_out.send_multipart([b"stdout", b"pong!"]) return 0 - async def do_resize_term(self, args) -> int: + async def do_resize_term(self, args: argparse.Namespace) -> int: if self.fd is None: return 0 origsz_in = struct.pack("HHHH", 0, 0, 0, 0) @@ -84,7 +92,7 @@ async def do_resize_term(self, args) -> int: ]) return 0 - async def handle_command(self, code_txt) -> int: + async def handle_command(self, code_txt: str) -> int: try: if code_txt.startswith("%"): args = self.cmdparser.parse_args(shlex.split(code_txt[1:], comments=True)) @@ -108,7 +116,9 @@ async def start(self) -> None: await safe_close_task(self.term_out_task) pid, fd = pty.fork() if pid == 0: - args = shlex.split(self.shell_cmd) + args = ( + shlex.split(self.shell_cmd) if isinstance(self.shell_cmd, str) else self.shell_cmd + ) os.execv(args[0], args) else: self.pid = pid @@ -156,7 +166,7 @@ async def restart(self) -> None: except Exception: log.exception("Unexpected error during restart of terminal") - async def term_in(self, term_writer) -> None: + async def term_in(self, term_writer: asyncio.StreamWriter) -> None: try: if self.sock_term_in is None: raise RuntimeError("Terminal input socket is not initialized") @@ -175,7 +185,7 @@ async def term_in(self, term_writer) -> None: except Exception: log.exception("Unexpected error at term_in()") - async def term_out(self, term_reader) -> None: + async def term_out(self, term_reader: asyncio.StreamReader) -> None: try: if self.sock_term_out is None: raise RuntimeError("Terminal output socket is not initialized") diff --git a/src/ai/backend/kernel/utils.py b/src/ai/backend/kernel/utils.py index 1ba18c076fd..d398614d853 100644 --- a/src/ai/backend/kernel/utils.py +++ b/src/ai/backend/kernel/utils.py @@ -22,7 +22,7 @@ CLOCK_TICK: Final = os.sysconf("SC_CLK_TCK") -def find_executable(*paths) -> Path | None: +def find_executable(*paths: Path | str | bytes) -> Path | None: """ Find the first executable regular file in the given list of paths. 
""" @@ -60,13 +60,13 @@ def filter(self, record: logging.LogRecord) -> bool: return True -async def safe_close_task(task) -> None: +async def safe_close_task(task: asyncio.Task[Any] | None) -> None: if task is not None and not task.done(): task.cancel() await task -async def wait_local_port_open(port) -> None: +async def wait_local_port_open(port: int) -> None: while True: try: async with asyncio.timeout(10.0): @@ -98,7 +98,7 @@ def scan_proc_stats() -> dict[int, dict]: return pid_set -def parse_proc_stat(pid) -> dict[str, Any]: +def parse_proc_stat(pid: int) -> dict[str, Any]: data = Path(f"/proc/{pid}/stat").read_bytes() name_begin = data.find(b"(") name_end = data.rfind(b")") diff --git a/src/ai/backend/kernel/vendor/aws_polly/__init__.py b/src/ai/backend/kernel/vendor/aws_polly/__init__.py index 7e3de80c5a3..3d15b1a2fd1 100644 --- a/src/ai/backend/kernel/vendor/aws_polly/__init__.py +++ b/src/ai/backend/kernel/vendor/aws_polly/__init__.py @@ -3,6 +3,8 @@ import logging import os import threading +from collections.abc import Mapping +from typing import Any import janus @@ -40,7 +42,7 @@ async def build_heuristic(self) -> int: async def execute_heuristic(self) -> int: raise NotImplementedError - async def query(self, code_text) -> int: + async def query(self, code_text: str) -> int: self.ensure_inproc_runner() if self.input_queue is None: raise RuntimeError("Input queue is not initialized") @@ -59,7 +61,7 @@ async def query(self, code_text) -> int: self.outsock.send_multipart(msg) return 0 - async def complete(self, data) -> None: + async def complete(self, data: Any) -> None: self.outsock.send_multipart([ b"completion", [], @@ -83,7 +85,7 @@ async def interrupt(self) -> None: ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(target_tid), ctypes.c_long(0)) log.error("Interrupt broke the interpreter state -- recommended to reset the session.") - async def start_service(self, service_info) -> tuple[None, dict]: + async def start_service(self, service_info: Mapping[str, Any]) -> tuple[None, dict]: return None, {} def ensure_inproc_runner(self) -> None: diff --git a/src/ai/backend/kernel/vendor/aws_polly/inproc.py b/src/ai/backend/kernel/vendor/aws_polly/inproc.py index 2f077ef6e12..efe5fd6712a 100644 --- a/src/ai/backend/kernel/vendor/aws_polly/inproc.py +++ b/src/ai/backend/kernel/vendor/aws_polly/inproc.py @@ -3,6 +3,8 @@ import json import logging import threading +from queue import Queue +from typing import Any from boto3 import Session from botocore.exceptions import BotoCoreError, ClientError # pants: no-infer-dep @@ -11,7 +13,14 @@ class PollyInprocRunner(threading.Thread): - def __init__(self, input_queue, output_queue, sentinel, access_key, secret_key) -> None: + def __init__( + self, + input_queue: Queue[str], + output_queue: Queue[Any], + sentinel: object, + access_key: str | None, + secret_key: str | None, + ) -> None: super().__init__(name="InprocRunner", daemon=True) # for interoperability with the main asyncio loop diff --git a/src/ai/backend/kernel/vendor/h2o/__init__.py b/src/ai/backend/kernel/vendor/h2o/__init__.py index 7fd9737a01d..eab7a6ef31a 100644 --- a/src/ai/backend/kernel/vendor/h2o/__init__.py +++ b/src/ai/backend/kernel/vendor/h2o/__init__.py @@ -1,7 +1,9 @@ import asyncio import logging import tempfile +from collections.abc import Mapping from pathlib import Path +from typing import Any from ai.backend.kernel import BaseRunner @@ -53,7 +55,7 @@ async def execute_heuristic(self) -> int: log.error('cannot find the main script ("main.py").') return 127 - 
async def start_service(self, service_info) -> tuple[list, dict] | None: + async def start_service(self, service_info: Mapping[str, Any]) -> tuple[list, dict] | None: if service_info["name"] in ["jupyter", "jupyterlab"]: with tempfile.NamedTemporaryFile( "w", encoding="utf-8", suffix=".py", delete=False diff --git a/src/ai/backend/logging/config.py b/src/ai/backend/logging/config.py index 9dd49c04006..6daea6ffc8d 100644 --- a/src/ai/backend/logging/config.py +++ b/src/ai/backend/logging/config.py @@ -291,7 +291,10 @@ class LogHandlerConfig(BaseConfigModel): ) @model_serializer(mode="wrap") - def rename_class(self, handler) -> dict[str, Any]: + def rename_class( + self, + handler: Any, + ) -> dict[str, Any]: data = handler(self) if "class_" in data: data["class"] = data.pop("class_") diff --git a/src/ai/backend/logging/formatter.py b/src/ai/backend/logging/formatter.py index 8aeddee7877..9f374343bf7 100644 --- a/src/ai/backend/logging/formatter.py +++ b/src/ai/backend/logging/formatter.py @@ -17,7 +17,10 @@ ) -def format_exception(self, ei: Sequence[str] | _SysExcInfoType) -> str: +def format_exception( + self: logging.Formatter, + ei: Sequence[str] | _SysExcInfoType, +) -> str: match ei: case (str(), *_): # Already formatted from the source process for ease of serialization @@ -31,12 +34,12 @@ def format_exception(self, ei: Sequence[str] | _SysExcInfoType) -> str: class SerializedExceptionFormatter(logging.Formatter): - def formatException(self, ei) -> str: + def formatException(self, ei: Sequence[str] | _SysExcInfoType) -> str: return format_exception(self, ei) class ConsoleFormatter(logging.Formatter): - def formatException(self, ei) -> str: + def formatException(self, ei: Sequence[str] | _SysExcInfoType) -> str: return format_exception(self, ei) def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: @@ -49,7 +52,7 @@ def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> s class CustomJsonFormatter(JsonFormatter): - def formatException(self, ei) -> str: + def formatException(self, ei: Sequence[str] | _SysExcInfoType) -> str: return format_exception(self, ei) def add_fields( diff --git a/src/ai/backend/logging/handler/graylog.py b/src/ai/backend/logging/handler/graylog.py index 83b602fa895..b3c969d8ec3 100644 --- a/src/ai/backend/logging/handler/graylog.py +++ b/src/ai/backend/logging/handler/graylog.py @@ -3,6 +3,7 @@ import logging import socket import ssl +from typing import Any import graypy @@ -12,7 +13,16 @@ class GELFTLSHandler(graypy.GELFTLSHandler): ssl_ctx: ssl.SSLContext - def __init__(self, host, port=12204, validate=False, ca_certs=None, **kwargs) -> None: + def __init__( + self, + host: str, + port: int = 12204, + validate: bool = False, + ca_certs: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + **kwargs, + ) -> None: """Initialize the GELFTLSHandler :param host: GELF TLS input host.
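The graylog.py hunk above shows the recurring ANN001 pattern for subclass constructors: every defaulted keyword parameter gets an explicit annotation, while `**kwargs` stays unannotated because ANN003 (missing-type-kwargs) remains in the ignore list. A minimal self-contained sketch of the same pattern follows; the `TLSLogHandler` class and its body are illustrative stand-ins, not code from this patch:

```python
import logging
import ssl


class TLSLogHandler(logging.Handler):
    """Illustrative stand-in for a TLS-based log handler subclass."""

    def __init__(
        self,
        host: str,
        port: int = 12204,
        validate: bool = False,
        ca_certs: str | None = None,
        **kwargs,  # left bare: ANN003 is still deferred in pyproject.toml
    ) -> None:
        super().__init__(**kwargs)
        # Relax certificate checks only when validation is explicitly disabled.
        self.ssl_ctx = ssl.create_default_context(cafile=ca_certs)
        if not validate:
            self.ssl_ctx.check_hostname = False
            self.ssl_ctx.verify_mode = ssl.CERT_NONE
        self.address = (host, port)

    def emit(self, record: logging.LogRecord) -> None:
        # A real handler would serialize `record` and send it over TLS.
        pass
```

The `graylog_params: dict[str, Any]` annotation in the next hunk complements this from the call site: a dict literal with mixed value types would otherwise be inferred too narrowly for the keyword arguments it is later unpacked into.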
@@ -56,7 +66,7 @@ def setup_graylog_handler(config: LoggingConfig) -> logging.Handler: drv_config = config.graylog if drv_config is None: raise RuntimeError("Graylog configuration is required but not provided") - graylog_params = { + graylog_params: dict[str, Any] = { "host": drv_config.host, "port": drv_config.port, "validate": drv_config.ssl_verify, diff --git a/src/ai/backend/manager/api/acl.py b/src/ai/backend/manager/api/acl.py index 56374161abd..40b6ac0ff44 100644 --- a/src/ai/backend/manager/api/acl.py +++ b/src/ai/backend/manager/api/acl.py @@ -11,6 +11,7 @@ from .auth import auth_required from .gql_legacy.acl import get_all_permissions from .manager import ALL_ALLOWED, server_status_required +from .types import CORSOptions log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -32,7 +33,7 @@ async def shutdown(app: web.Application) -> None: pass -def create_app(default_cors_options) -> tuple[web.Application, list]: +def create_app(default_cors_options: CORSOptions) -> tuple[web.Application, list]: app = web.Application() app["prefix"] = "acl" app["api_versions"] = (4,) diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py index 8c1a86fc61e..a1b9c00b5a7 100644 --- a/src/ai/backend/manager/api/admin.py +++ b/src/ai/backend/manager/api/admin.py @@ -2,7 +2,7 @@ import logging import traceback -from collections.abc import Iterable +from collections.abc import Callable, Iterable from http import HTTPStatus from typing import TYPE_CHECKING, Any, cast @@ -51,7 +51,7 @@ class CustomGraphQLView(GraphQLView): """Custom GraphQL view for Backend.AI with OpenAPI compatibility.""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.__name__ = "handle_graphql_strawberry" self.__doc__ = """ @@ -84,7 +84,7 @@ async def get_context( # type: ignore[override] class GQLLoggingMiddleware: - def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + def resolve(self, next: Callable, root: Any, info: graphene.ResolveInfo, **args: Any) -> Any: if info.path.prev is None: # indicates the root query graph_ctx = info.context log.info( @@ -98,7 +98,7 @@ def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: class CustomIntrospectionRule(ValidationRule): - def enter_field(self, node: FieldNode, *_args) -> None: + def enter_field(self, node: FieldNode, *_args: Any) -> None: field_name = node.name.value if field_name.startswith("__"): # Allow __typename field for GraphQL Federation, @connection directive diff --git a/src/ai/backend/manager/api/auth.py b/src/ai/backend/manager/api/auth.py index da9a44325c2..4107d5d7d10 100644 --- a/src/ai/backend/manager/api/auth.py +++ b/src/ai/backend/manager/api/auth.py @@ -301,7 +301,7 @@ whois_timezone_info: Final[Mapping[str, int]] = {k: int(v) for k, v in _whois_timezone_info.items()} -def _extract_auth_params(request) -> tuple[str, str, str] | None: +def _extract_auth_params(request: web.Request) -> tuple[str, str, str] | None: """ HTTP Authorization header must be formatted as: "Authorization: BackendAI signMethod=HMAC-SHA256, @@ -717,7 +717,7 @@ def _setup_user_context(request: web.Request) -> ExitStack: @web.middleware -async def auth_middleware(request: web.Request, handler) -> web.StreamResponse: +async def auth_middleware(request: web.Request, handler: Handler) -> web.StreamResponse: """ Unified authentication middleware - routes to appropriate authentication flow. 
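The auth.py change above gives the middleware's `handler` parameter a concrete `Handler` type instead of leaving it implicit. For reference, here is a minimal sketch of a fully annotated aiohttp middleware that satisfies ANN001; the alias is spelled out for illustration (aiohttp itself ships an equivalent `Handler` in `aiohttp.typedefs`), and the authorization check is a stand-in for the real flow:

```python
from collections.abc import Awaitable, Callable

from aiohttp import web

# Illustrative alias; the codebase imports its own Handler type instead.
Handler = Callable[[web.Request], Awaitable[web.StreamResponse]]


@web.middleware
async def auth_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
    # Stand-in check: the real middleware routes to the project's auth flows.
    if not request.get("is_authorized", False):
        raise web.HTTPUnauthorized(reason="Unauthorized access")
    return await handler(request)
```

Annotating middlewares through a shared alias keeps every wrapper signature consistent and lets mypy verify the `await handler(request)` call.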
diff --git a/src/ai/backend/manager/api/etcd.py b/src/ai/backend/manager/api/etcd.py index 4aa440e81ea..494a40d042a 100644 --- a/src/ai/backend/manager/api/etcd.py +++ b/src/ai/backend/manager/api/etcd.py @@ -211,7 +211,7 @@ async def set_config(request: web.Request, params: Any) -> web.Response: if isinstance(params["value"], Mapping): updates = {} - def flatten(prefix, o) -> None: + def flatten(prefix: str, o: Mapping[str, Any]) -> None: for k, v in o.items(): inner_prefix = prefix if k == "" else f"{prefix}/{k}" if isinstance(v, Mapping): diff --git a/src/ai/backend/manager/api/gql/base.py b/src/ai/backend/manager/api/gql/base.py index ab4026fb84f..437559ca8c9 100644 --- a/src/ai/backend/manager/api/gql/base.py +++ b/src/ai/backend/manager/api/gql/base.py @@ -10,6 +10,7 @@ import graphene import strawberry from graphql import StringValueNode +from graphql.language.ast import ValueNode from graphql_relay.utils import base64, unbase64 from strawberry.types import get_object_definition, has_object_definition @@ -58,7 +59,7 @@ def parse_value(value: str) -> str: return value @staticmethod - def parse_literal(ast) -> str: + def parse_literal(ast: ValueNode) -> str: if not isinstance(ast, StringValueNode): raise ValueError("ByteSize must be provided as a string literal") return ast.value diff --git a/src/ai/backend/manager/api/gql/vfolder.py b/src/ai/backend/manager/api/gql/vfolder.py index 63f6754d50a..b0d0a27dece 100644 --- a/src/ai/backend/manager/api/gql/vfolder.py +++ b/src/ai/backend/manager/api/gql/vfolder.py @@ -27,7 +27,7 @@ class ExtraVFolderMount(Node): @strawberry.field async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder: - vfolder_global_id = AsyncNode.to_global_id("VirtualFolderNode", self._vfolder_id) + vfolder_global_id = AsyncNode.to_global_id("VirtualFolderNode", str(self._vfolder_id)) return VFolder(id=ID(vfolder_global_id)) @classmethod diff --git a/src/ai/backend/manager/api/gql_legacy/agent.py b/src/ai/backend/manager/api/gql_legacy/agent.py index a98ceee27a6..79a25febd89 100644 --- a/src/ai/backend/manager/api/gql_legacy/agent.py +++ b/src/ai/backend/manager/api/gql_legacy/agent.py @@ -289,7 +289,7 @@ async def get_connection( cnt_query = cnt_query.where(cond) else: - async def all_permissions(row) -> frozenset[AgentPermission]: + async def all_permissions(row: AgentRow) -> frozenset[AgentPermission]: return ADMIN_PERMISSIONS permission_getter = all_permissions # type: ignore[assignment] @@ -868,7 +868,7 @@ class Arguments: ) async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, props: ModifyAgentInput, @@ -906,7 +906,7 @@ class Arguments: ) async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, agent_id: str, ) -> RescanGPUAllocMaps: diff --git a/src/ai/backend/manager/api/gql_legacy/audit_log.py b/src/ai/backend/manager/api/gql_legacy/audit_log.py index e7d99b15222..df07ea30719 100644 --- a/src/ai/backend/manager/api/gql_legacy/audit_log.py +++ b/src/ai/backend/manager/api/gql_legacy/audit_log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping from datetime import timedelta from typing import TYPE_CHECKING, Optional, Self, cast @@ -99,7 +101,7 @@ class Meta: } @classmethod - def from_row(cls, ctx, row: AuditLogRow) -> Self: + def from_row(cls, ctx: GraphQueryContext, row: AuditLogRow) -> Self: return cls( id=row.id, row_id=row.id, diff --git a/src/ai/backend/manager/api/gql_legacy/base.py b/src/ai/backend/manager/api/gql_legacy/base.py index 7fe0548a473..6fb0c6dcbb7 
100644 --- a/src/ai/backend/manager/api/gql_legacy/base.py +++ b/src/ai/backend/manager/api/gql_legacy/base.py @@ -53,6 +53,7 @@ QueryOrderParser, ) from ai.backend.manager.models.minilang.queryfilter import QueryFilterParser, WhereClauseType +from ai.backend.manager.models.user import UserRole from ai.backend.manager.models.utils import execute_with_retry from .gql_relay import ( @@ -65,8 +66,6 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql.selectable import ScalarSelect - from ai.backend.manager.models.user import UserRole - from .schema import GraphQueryContext log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -146,7 +145,7 @@ def serialize(val: bytes) -> str: return val.hex() @staticmethod - def parse_literal(node: Any, _variables=None) -> Optional[bytes]: + def parse_literal(node: Any, _variables: Optional[dict[str, Any]] = None) -> Optional[bytes]: if isinstance(node, graphql.language.ast.StringValueNode): return bytes.fromhex(node.value) return None @@ -186,7 +185,9 @@ def serialize(value: Any) -> dict[str, float]: return validated @classmethod - def parse_literal(cls, node: ValueNode, _variables=None) -> dict[str, float]: + def parse_literal( + cls, node: ValueNode, _variables: Optional[dict[str, Any]] = None + ) -> dict[str, float]: if not isinstance(node, ObjectValueNode): raise GraphQLError(f"UUIDFloatMap cannot represent non-object value: {print_ast(node)}") validated: dict[str, Any] = {} @@ -291,7 +292,7 @@ def _get_type_cache(cls) -> dict[str, Any]: return cls._type_cache @staticmethod - def _get_key(otname: str, args, kwargs) -> int: + def _get_key(otname: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> int: """ Calculate the hash of all the arguments and keyword arguments. """ @@ -524,7 +525,7 @@ async def batch_multiresult_in_scalar_stream( def privileged_query(required_role: UserRole) -> Callable: - def wrap(func) -> Callable: + def wrap(func: Callable) -> Callable: @functools.wraps(func) async def wrapped( root: Any, @@ -560,7 +561,7 @@ def scoped_query( in the keyword arguments.
""" - def wrap(resolve_func) -> Callable: + def wrap(resolve_func: Callable) -> Callable: @functools.wraps(resolve_func) async def wrapped( root: Any, @@ -622,10 +623,10 @@ async def wrapped( return wrap -def privileged_mutation(required_role, target_func=None) -> Callable: - def wrap(func) -> Callable: +def privileged_mutation(required_role: UserRole, target_func: Callable | None = None) -> Callable: + def wrap(func: Callable) -> Callable: @functools.wraps(func) - async def wrapped(cls, root, info: graphene.ResolveInfo, *args, **kwargs) -> Any: + async def wrapped(cls: type, root: Any, info: graphene.ResolveInfo, *args, **kwargs) -> Any: from ai.backend.manager.models.group import groups # , association_groups_users from ai.backend.manager.models.user import UserRole @@ -809,7 +810,7 @@ def set_if_set( target: MutableMapping[str, Any], name: str, *, - clean_func=None, + clean_func: Callable[[Any], Any] | None = None, target_key: Optional[str] = None, ) -> None: """ @@ -835,7 +836,7 @@ def orm_set_if_set( target: MutableMapping[str, Any], name: str, *, - clean_func=None, + clean_func: Callable[[Any], Any] | None = None, target_key: Optional[str] = None, ) -> None: """ @@ -870,11 +871,11 @@ class InferenceSessionErrorInfo(graphene.ObjectType): class AsyncPaginatedConnectionField(AsyncListConnectionField): - def __init__(self, type, *args, **kwargs) -> None: + def __init__(self, type: type | str, *args, **kwargs) -> None: kwargs.setdefault("filter", graphene.String()) kwargs.setdefault("order", graphene.String()) kwargs.setdefault("offset", graphene.Int()) - super().__init__(type, *args, **kwargs) + super().__init__(type, *args, **kwargs) # type: ignore[arg-type] PaginatedConnectionField = AsyncPaginatedConnectionField @@ -943,7 +944,7 @@ class _StmtWithConditions: def _apply_ordering( stmt: sa.sql.Select, - id_column: sa.Column | InstrumentedAttribute, + id_column: sa.Column[Any] | InstrumentedAttribute[Any], ordering_item_list: list[OrderingItem], pagination_order: ConnectionPaginationOrder | None, ) -> sa.sql.Select: @@ -956,16 +957,27 @@ def _apply_ordering( case ConnectionPaginationOrder.FORWARD | None: # Default ordering by id column (ascending for forward pagination) id_ordering_item = OrderingItem(id_column, OrderDirection.ASC) - set_ordering = lambda col, direction: ( - col.asc() if direction == OrderDirection.ASC else col.desc() - ) + + def set_ordering( + col: sa.Column[Any] + | InstrumentedAttribute[Any] + | sa.sql.elements.KeyedColumnElement[Any], + direction: OrderDirection, + ) -> Any: + return col.asc() if direction == OrderDirection.ASC else col.desc() + case ConnectionPaginationOrder.BACKWARD: # Default ordering by id column (descending for backward pagination) id_ordering_item = OrderingItem(id_column, OrderDirection.DESC) + # Reverse ordering direction for backward pagination - set_ordering = lambda col, direction: ( - col.desc() if direction == OrderDirection.ASC else col.asc() - ) + def set_ordering( + col: sa.Column[Any] + | InstrumentedAttribute[Any] + | sa.sql.elements.KeyedColumnElement[Any], + direction: OrderDirection, + ) -> Any: + return col.desc() if direction == OrderDirection.ASC else col.asc() # Apply ordering to stmt (id column should be applied last for deterministic ordering) for col, direction in [*ordering_item_list, id_ordering_item]: @@ -976,7 +988,7 @@ def _apply_ordering( def _apply_filter_conditions( stmt: sa.sql.Select, - orm_class, + orm_class: type[Any], filter_expr: FilterExprArg, ) -> _StmtWithConditions: """ @@ -995,7 +1007,7 @@ def 
_apply_filter_conditions( def _apply_cursor_pagination( info: graphene.ResolveInfo, stmt: sa.sql.Select, - id_column: sa.Column | InstrumentedAttribute, + id_column: sa.Column[Any] | InstrumentedAttribute[Any], ordering_item_list: list[OrderingItem], cursor_id: str, pagination_order: ConnectionPaginationOrder | None, @@ -1014,8 +1026,10 @@ def _apply_cursor_pagination( cursor_row_id = cursor_row_id_str def subq_to_condition( - column_to_be_compared: sa.Column | InstrumentedAttribute, - subquery: ScalarSelect, + column_to_be_compared: sa.Column[Any] + | InstrumentedAttribute[Any] + | sa.sql.elements.KeyedColumnElement[Any], + subquery: ScalarSelect[Any], direction: OrderDirection, ) -> WhereClauseType: """Generate cursor condition for a specific ordering column. @@ -1071,8 +1085,8 @@ def subq_to_condition( def _build_sql_stmt_from_connection_args( info: graphene.ResolveInfo, - orm_class, - id_column: sa.Column | InstrumentedAttribute, + orm_class: type[Any], + id_column: sa.Column[Any] | InstrumentedAttribute[Any], filter_expr: FilterExprArg | None = None, order_expr: OrderExprArg | None = None, *, @@ -1118,8 +1132,8 @@ def _build_sql_stmt_from_connection_args( def _build_sql_stmt_from_sql_arg( info: graphene.ResolveInfo, - orm_class, - id_column: sa.Column | InstrumentedAttribute, + orm_class: type[Any], + id_column: sa.Column[Any] | InstrumentedAttribute[Any], filter_expr: FilterExprArg | None = None, order_expr: OrderExprArg | None = None, *, @@ -1173,7 +1187,7 @@ class OrderExprArg(NamedTuple): def generate_sql_info_for_gql_connection( info: graphene.ResolveInfo, - orm_class, + orm_class: type[Any], id_column: sa.Column[Any] | InstrumentedAttribute[Any], filter_expr: FilterExprArg | None = None, order_expr: OrderExprArg | None = None, diff --git a/src/ai/backend/manager/api/gql_legacy/container_registry.py b/src/ai/backend/manager/api/gql_legacy/container_registry.py index d6e975dfa4b..f8649886d7b 100644 --- a/src/ai/backend/manager/api/gql_legacy/container_registry.py +++ b/src/ai/backend/manager/api/gql_legacy/container_registry.py @@ -335,7 +335,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, url: str, type: ContainerRegistryType, @@ -345,7 +345,7 @@ async def mutate( username: str | UndefinedType = Undefined, password: str | UndefinedType = Undefined, ssl_verify: bool | UndefinedType = Undefined, - extra: dict | UndefinedType = Undefined, + extra: dict[str, Any] | UndefinedType = Undefined, ) -> CreateContainerRegistryNode: ctx: GraphQueryContext = info.context validator = ContainerRegistryValidator( @@ -417,7 +417,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, url: str | UndefinedType = Undefined, @@ -428,7 +428,7 @@ async def mutate( username: str | UndefinedType = Undefined, password: str | UndefinedType = Undefined, ssl_verify: bool | UndefinedType = Undefined, - extra: dict | UndefinedType = Undefined, + extra: dict[str, Any] | UndefinedType = Undefined, ) -> ModifyContainerRegistryNode: ctx: GraphQueryContext = info.context @@ -483,7 +483,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, ) -> DeleteContainerRegistryNode: @@ -520,7 +520,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scope_id: ScopeType, quota: int | float, @@ -560,7 +560,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: 
graphene.ResolveInfo, scope_id: ScopeType, quota: int | float, @@ -599,7 +599,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scope_id: ScopeType, ) -> Self: @@ -732,7 +732,11 @@ class Arguments: @classmethod async def mutate( - cls, root, info: graphene.ResolveInfo, hostname: str, props: CreateContainerRegistryInput + cls, + root: Any, + info: graphene.ResolveInfo, + hostname: str, + props: CreateContainerRegistryInput, ) -> CreateContainerRegistry: ctx: GraphQueryContext = info.context @@ -776,7 +780,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, hostname: str, props: ModifyContainerRegistryInput, @@ -827,7 +831,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, hostname: str, ) -> DeleteContainerRegistry: diff --git a/src/ai/backend/manager/api/gql_legacy/container_registry_v2.py b/src/ai/backend/manager/api/gql_legacy/container_registry_v2.py index 1da2f6de7ed..620c0592a9d 100644 --- a/src/ai/backend/manager/api/gql_legacy/container_registry_v2.py +++ b/src/ai/backend/manager/api/gql_legacy/container_registry_v2.py @@ -77,7 +77,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, props: CreateContainerRegistryNodeInputV2, ) -> CreateContainerRegistryNodeV2: @@ -193,7 +193,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, props: ModifyContainerRegistryNodeInputV2, @@ -228,7 +228,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, ) -> DeleteContainerRegistryNodeV2: diff --git a/src/ai/backend/manager/api/gql_legacy/domain.py b/src/ai/backend/manager/api/gql_legacy/domain.py index c9a0ca701a7..be451ae0618 100644 --- a/src/ai/backend/manager/api/gql_legacy/domain.py +++ b/src/ai/backend/manager/api/gql_legacy/domain.py @@ -327,7 +327,7 @@ async def get_connection( ) return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) - async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> DomainNode: + async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs: Any) -> DomainNode: domain_node = await DomainNode.get_node(info, self.id) if domain_node is None: raise DomainNotFound(f"Domain not found: {self.id}") @@ -377,7 +377,7 @@ class Meta: scaling_groups = graphene.List(lambda: graphene.String, required=False) def to_action(self, user_info: UserInfo) -> CreateDomainNodeAction: - def value_or_none(value) -> Any: + def value_or_none(value: Any) -> Any: return value if value is not graphql.Undefined else None return CreateDomainNodeAction( @@ -642,7 +642,7 @@ class DomainInput(graphene.InputObjectType): integration_id = graphene.String(required=False, default_value=None) def to_action(self, domain_name: str, user_info: UserInfo) -> CreateDomainAction: - def value_or_none(value) -> Any: + def value_or_none(value: Any) -> Any: return value if value is not Undefined else None return CreateDomainAction( @@ -721,7 +721,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, props: DomainInput, @@ -757,7 +757,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, props: ModifyDomainInput, @@ -793,7 +793,7 @@ class Arguments: msg = graphene.String() @classmethod - 
async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> DeleteDomain: + async def mutate(cls, root: Any, info: graphene.ResolveInfo, name: str) -> DeleteDomain: ctx: GraphQueryContext = info.context user_info: UserInfo = UserInfo( @@ -824,7 +824,7 @@ class Arguments: msg = graphene.String() @classmethod - async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> PurgeDomain: + async def mutate(cls, root: Any, info: graphene.ResolveInfo, name: str) -> PurgeDomain: ctx: GraphQueryContext = info.context user_info: UserInfo = UserInfo( diff --git a/src/ai/backend/manager/api/gql_legacy/endpoint.py b/src/ai/backend/manager/api/gql_legacy/endpoint.py index 2ed92bc1afa..d29286677b1 100644 --- a/src/ai/backend/manager/api/gql_legacy/endpoint.py +++ b/src/ai/backend/manager/api/gql_legacy/endpoint.py @@ -428,7 +428,7 @@ class Meta: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, endpoint: str, props: EndpointAutoScalingRuleInput, @@ -476,7 +476,7 @@ class Meta: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, props: ModifyEndpointAutoScalingRuleInput, @@ -522,7 +522,7 @@ class Meta: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: str, ) -> Self: @@ -659,7 +659,7 @@ class Meta: @classmethod async def from_row( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, row: EndpointRow, ) -> Self: creator = cast(Optional[UserRow], row.created_user_row) @@ -699,13 +699,13 @@ async def from_row( created_at=row.created_at, destroyed_at=row.destroyed_at, retries=row.retries, - routings=[await Routing.from_row(None, r, endpoint=row) for r in row.routings], + routings=[await Routing.from_row(ctx, r, endpoint=row) for r in row.routings], lifecycle_stage=row.lifecycle_stage.name, runtime_variant=RuntimeVariantInfo.from_enum(row.runtime_variant), ) @classmethod - def from_dto(cls, ctx, dto: Optional[EndpointData]) -> Optional[Self]: + def from_dto(cls, ctx: GraphQueryContext, dto: Optional[EndpointData]) -> Optional[Self]: if dto is None: return None return cls( @@ -749,7 +749,7 @@ def from_dto(cls, ctx, dto: Optional[EndpointData]) -> Optional[Self]: @classmethod async def load_count( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, project: UUID | None = None, domain_name: Optional[str] = None, @@ -776,12 +776,12 @@ async def load_count( async with ctx.db.begin_readonly() as conn: result = await conn.execute(query) - return result.scalar() + return result.scalar() or 0 @classmethod async def load_slice( cls, - ctx, #: GraphQueryContext, # ctx: GraphQueryContext, + ctx: GraphQueryContext, limit: int, offset: int, *, @@ -831,7 +831,7 @@ async def load_slice( @classmethod async def load_all( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, domain_name: Optional[str] = None, user_uuid: Optional[UUID] = None, @@ -852,7 +852,7 @@ async def load_all( @classmethod async def load_item( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, endpoint_id: UUID, domain_name: Optional[str] = None, @@ -1111,7 +1111,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, endpoint_id: UUID, props: ModifyEndpointInput, @@ -1149,7 +1149,7 @@ class Meta: @classmethod async def from_row( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, row: EndpointTokenRow, ) -> Self: return cls( @@ -1164,7 +1164,7 @@ async def from_row( @classmethod async def load_count( cls, - 
ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, endpoint_id: Optional[UUID] = None, project: Optional[UUID] = None, @@ -1182,12 +1182,12 @@ async def load_count( query = query.filter(EndpointTokenRow.session_owner == user_uuid) async with ctx.db.begin_readonly() as conn: result = await conn.execute(query) - return result.scalar() + return result.scalar() or 0 @classmethod async def load_slice( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, limit: int, offset: int, *, @@ -1247,7 +1247,7 @@ async def load_all( @classmethod async def load_item( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, token: str, *, project: Optional[UUID] = None, diff --git a/src/ai/backend/manager/api/gql_legacy/fields.py b/src/ai/backend/manager/api/gql_legacy/fields.py index fb5e4ea1c2a..778274d11f3 100644 --- a/src/ai/backend/manager/api/gql_legacy/fields.py +++ b/src/ai/backend/manager/api/gql_legacy/fields.py @@ -27,7 +27,7 @@ def serialize(val: ScopeType) -> str: return val.serialize() @staticmethod - def parse_literal(node: Any, _variables=None) -> Optional[ScopeType]: + def parse_literal(node: Any, _variables: dict | None = None) -> Optional[ScopeType]: if isinstance(node, graphql.language.ast.StringValueNode): return deserialize_scope(node.value) return None diff --git a/src/ai/backend/manager/api/gql_legacy/gql_relay.py b/src/ai/backend/manager/api/gql_legacy/gql_relay.py index a14304e19e3..64e79d7f3de 100644 --- a/src/ai/backend/manager/api/gql_legacy/gql_relay.py +++ b/src/ai/backend/manager/api/gql_legacy/gql_relay.py @@ -22,7 +22,6 @@ ) from graphene.relay.node import Node, NodeField, is_node from graphene.types import Field, NonNull, ObjectType, String -from graphene.types.objecttype import ObjectTypeMeta from graphene.types.utils import get_type from graphql_relay.utils import base64, unbase64 @@ -37,7 +36,7 @@ def get_edge_class( base_name: str, strict_types: bool = False, description: str | None = None, -) -> ObjectTypeMeta: +) -> type[ObjectType]: edge_class = getattr(connection_class, "Edge", None) class EdgeBase: @@ -116,7 +115,7 @@ def __call__( class AsyncNodeField(NodeField): - def wrap_resolve(self, parent_resolver) -> functools.partial: + def wrap_resolve(self, parent_resolver: Callable) -> functools.partial: return functools.partial(self.node_type.node_resolver, get_type(self.field_type)) @@ -144,17 +143,19 @@ async def node_resolver(cls, only_type: type, root: Any, info: Any, id: str) -> return await cls.get_node_from_global_id(info, id, only_type=only_type) @staticmethod - def to_global_id(type_, id_) -> str: + def to_global_id(type_: str, id_: str) -> str: if id_ is None: raise Exception("Encoding None value as Global ID is not allowed.") return base64(f"{type_}:{id_}") @classmethod - def resolve_global_id(cls, info, global_id: str) -> tuple[str, str]: + def resolve_global_id(cls, info: Any, global_id: str) -> tuple[str, str]: return _resolve_global_id(global_id) @classmethod - async def get_node_from_global_id(cls, info, global_id: str, only_type=None) -> Any: + async def get_node_from_global_id( + cls, info: Any, global_id: str, only_type: type | None = None + ) -> Any: _type, _ = cls.resolve_global_id(info, global_id) graphene_type = info.schema.get_type(_type) @@ -164,10 +165,15 @@ async def get_node_from_global_id(cls, info, global_id: str, only_type=None) -> graphene_type = graphene_type.graphene_type if only_type: + # Use hasattr to check for _meta attribute safely + if not hasattr(only_type, "_meta"): + raise 
ServerMisconfiguredError("GraphQL type missing _meta attribute") if graphene_type != only_type: - raise InvalidAPIParameters(f"Must receive a {only_type._meta.name} id.") + only_type_name = getattr(getattr(only_type, "_meta", None), "name", str(only_type)) + raise InvalidAPIParameters(f"Must receive a {only_type_name} id.") - if cls not in graphene_type._meta.interfaces: + _meta = getattr(graphene_type, "_meta", None) + if _meta is None or cls not in _meta.interfaces: raise Exception(f'ObjectType "{_type}" does not implement the "{cls}" interface.') get_node = getattr(graphene_type, "get_node", None) @@ -302,7 +308,7 @@ def type(self) -> Any: @classmethod def resolve_connection( cls, - connection_type: ConnectionConstructor, + connection_type: Any, args: dict[str, Any] | None, resolver_result: ConnectionResolverResult, ) -> Connection: @@ -325,7 +331,7 @@ def resolve_connection( resolved = resolved[:page_size] if pagination_order == ConnectionPaginationOrder.BACKWARD: resolved = resolved[::-1] - edge_type = connection_type.Edge + edge_type = connection_type.Edge # type: ignore[attr-defined] edges = [ edge_type( node=value, @@ -333,7 +339,7 @@ def resolve_connection( ) for value in resolved ] - return connection_type( + return connection_type( # type: ignore[operator] edges=edges, page_info=PageInfo( start_cursor=edges[0].cursor if edges else None, @@ -356,10 +362,10 @@ def resolve_connection( async def connection_resolver( cls, resolver: Resolver, - connection_type: ConnectionConstructor, - root, - info, - **args, + connection_type: Any, + root: Any, + info: graphene.ResolveInfo, + **args: Any, ) -> Connection: _result = resolver(root, info, **args) match _result: @@ -405,7 +411,9 @@ def serialize(val: Any) -> Any: return val @staticmethod - def parse_literal(node: Any, _variables=None) -> ResolvedGlobalID | None: + def parse_literal( + node: Any, _variables: dict[str, Any] | None = None + ) -> ResolvedGlobalID | None: if isinstance(node, graphql.language.ast.StringValueNode): return _from_str(node.value) return None diff --git a/src/ai/backend/manager/api/gql_legacy/group.py b/src/ai/backend/manager/api/gql_legacy/group.py index dbd738b890a..361702116fc 100644 --- a/src/ai/backend/manager/api/gql_legacy/group.py +++ b/src/ai/backend/manager/api/gql_legacy/group.py @@ -236,7 +236,7 @@ async def resolve_registry_quota(self, info: graphene.ResolveInfo) -> int: ) @classmethod - async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: + async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self: graph_ctx: GraphQueryContext = info.context _, group_id = AsyncNode.resolve_global_id(info, id) query = sa.select(GroupRow).where(GroupRow.id == group_id) @@ -559,7 +559,7 @@ class GroupInput(graphene.InputObjectType): ) def to_action(self, name: str) -> CreateGroupAction: - def value_or_none(value) -> Any: + def value_or_none(value: Any) -> Any: return value if value is not Undefined else None type_val = None if self.type is Undefined else ProjectType[self.type] @@ -673,7 +673,7 @@ class Arguments: ) async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, props: GroupInput, @@ -711,7 +711,7 @@ class Arguments: ) async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, gid: uuid.UUID, props: ModifyGroupInput, @@ -745,7 +745,7 @@ class Arguments: UserRole.ADMIN, lambda gid, **kwargs: (None, gid), ) - async def mutate(cls, root, info: graphene.ResolveInfo, gid: uuid.UUID) -> DeleteGroup: + async def mutate(cls, root: Any, info: 
graphene.ResolveInfo, gid: uuid.UUID) -> DeleteGroup: ctx: GraphQueryContext = info.context await ctx.processors.group.delete_group.wait_for_complete(DeleteGroupAction(gid)) return cls(ok=True, msg="success") @@ -773,7 +773,7 @@ class Arguments: UserRole.ADMIN, lambda gid, **kwargs: (None, gid), ) - async def mutate(cls, root, info: graphene.ResolveInfo, gid: uuid.UUID) -> PurgeGroup: + async def mutate(cls, root: Any, info: graphene.ResolveInfo, gid: uuid.UUID) -> PurgeGroup: graph_ctx: GraphQueryContext = info.context await graph_ctx.processors.group.purge_group.wait_for_complete(PurgeGroupAction(gid)) diff --git a/src/ai/backend/manager/api/gql_legacy/image.py b/src/ai/backend/manager/api/gql_legacy/image.py index de90ed0b3b4..2f368c7fc6c 100644 --- a/src/ai/backend/manager/api/gql_legacy/image.py +++ b/src/ai/backend/manager/api/gql_legacy/image.py @@ -525,19 +525,27 @@ def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ... @overload @classmethod def from_row( - cls, graph_ctx, row: ImageRow, *, permissions: Optional[Iterable[ImagePermission]] = None + cls, + graph_ctx: GraphQueryContext, + row: ImageRow, + *, + permissions: Optional[Iterable[ImagePermission]] = None, ) -> ImageNode: ... @overload @classmethod def from_row( - cls, graph_ctx, row: None, *, permissions: Optional[Iterable[ImagePermission]] = None + cls, + graph_ctx: GraphQueryContext, + row: None, + *, + permissions: Optional[Iterable[ImagePermission]] = None, ) -> None: ... @classmethod def from_row( cls, - graph_ctx, + graph_ctx: GraphQueryContext, row: Optional[ImageRow], *, permissions: Optional[Iterable[ImagePermission]] = None, @@ -722,7 +730,7 @@ async def get_connection( return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) # TODO: Introduce access control logic considering scope and permission - async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> Image: + async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs: Any) -> Image: ctx: GraphQueryContext = info.context _, image_id = AsyncNode.resolve_global_id(info, self.id) action_result = await ctx.processors.image.get_image_by_id.wait_for_complete( diff --git a/src/ai/backend/manager/api/gql_legacy/keypair.py b/src/ai/backend/manager/api/gql_legacy/keypair.py index ce6c40e1a64..981cdef8653 100644 --- a/src/ai/backend/manager/api/gql_legacy/keypair.py +++ b/src/ai/backend/manager/api/gql_legacy/keypair.py @@ -502,7 +502,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, user_id: str, props: KeyPairInput, @@ -531,7 +531,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, access_key: AccessKey, props: ModifyKeyPairInput, @@ -559,7 +559,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, access_key: AccessKey, ) -> DeleteKeyPair: diff --git a/src/ai/backend/manager/api/gql_legacy/resource_preset.py b/src/ai/backend/manager/api/gql_legacy/resource_preset.py index 0d87653ff9d..e075758c8cd 100644 --- a/src/ai/backend/manager/api/gql_legacy/resource_preset.py +++ b/src/ai/backend/manager/api/gql_legacy/resource_preset.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from uuid import UUID import graphene @@ -240,7 +240,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: 
graphene.ResolveInfo, name: str, props: CreateResourcePresetInput, @@ -278,7 +278,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: Optional[UUID], name: Optional[str], @@ -316,7 +316,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, id: Optional[UUID], name: Optional[str], diff --git a/src/ai/backend/manager/api/gql_legacy/routing.py b/src/ai/backend/manager/api/gql_legacy/routing.py index c825020c6cf..7b2fbc7086e 100644 --- a/src/ai/backend/manager/api/gql_legacy/routing.py +++ b/src/ai/backend/manager/api/gql_legacy/routing.py @@ -41,7 +41,7 @@ class Meta: _endpoint_row: EndpointRow @classmethod - def from_dto(cls, dto) -> Optional[Self]: # type: ignore + def from_dto(cls, dto: Any) -> Optional[Self]: if dto is None: return None return cls( @@ -57,7 +57,7 @@ def from_dto(cls, dto) -> Optional[Self]: # type: ignore @classmethod async def from_row( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, row: RoutingRow, endpoint: Optional[EndpointRow] = None, ) -> Routing: @@ -76,7 +76,7 @@ async def from_row( @classmethod async def load_count( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, endpoint_id: Optional[uuid.UUID] = None, project: Optional[uuid.UUID] = None, @@ -94,12 +94,12 @@ async def load_count( query = query.filter(RoutingRow.session_owner == user_uuid) async with ctx.db.begin_readonly() as conn: result = await conn.execute(query) - return result.scalar() + return result.scalar() or 0 @classmethod async def load_slice( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, limit: int, offset: int, *, @@ -134,12 +134,14 @@ async def load_slice( query = parser.append_ordering(query, order) """ async with ctx.db.begin_readonly_session() as session: - return [await cls.from_row(ctx, row) async for row in (await session.stream(query))] + return [ + await cls.from_row(ctx, row) async for row in (await session.stream_scalars(query)) + ] @classmethod async def load_all( cls, - ctx, # ctx: GraphQueryContext + ctx: GraphQueryContext, endpoint_id: uuid.UUID, *, project: Optional[uuid.UUID] = None, @@ -159,7 +161,7 @@ async def load_all( @classmethod async def load_item( cls, - ctx, # ctx: GraphQueryContext, + ctx: GraphQueryContext, *, routing_id: uuid.UUID, project: Optional[uuid.UUID] = None, diff --git a/src/ai/backend/manager/api/gql_legacy/scaling_group.py b/src/ai/backend/manager/api/gql_legacy/scaling_group.py index b6d02eb2a7c..b7acfc2c921 100644 --- a/src/ai/backend/manager/api/gql_legacy/scaling_group.py +++ b/src/ai/backend/manager/api/gql_legacy/scaling_group.py @@ -701,7 +701,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, props: CreateScalingGroupInput, @@ -758,7 +758,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, props: ModifyScalingGroupInput, @@ -782,7 +782,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, name: str, ) -> DeleteScalingGroup: @@ -808,7 +808,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, domain: str, @@ -845,7 +845,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], domain: str, @@ -881,7 +881,7 @@ class Arguments: @classmethod async def mutate( 
cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, domain: str, @@ -916,7 +916,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], domain: str, @@ -948,7 +948,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, domain: str, ) -> DisassociateAllScalingGroupsWithDomain: @@ -979,7 +979,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, user_group: uuid.UUID, @@ -1016,7 +1016,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], user_group: uuid.UUID, @@ -1052,7 +1052,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, user_group: uuid.UUID, @@ -1087,7 +1087,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], user_group: uuid.UUID, @@ -1119,7 +1119,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, user_group: uuid.UUID, ) -> DisassociateAllScalingGroupsWithGroup: @@ -1150,7 +1150,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, access_key: str, @@ -1187,7 +1187,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], access_key: str, @@ -1223,7 +1223,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_group: str, access_key: str, @@ -1258,7 +1258,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, scaling_groups: Sequence[str], access_key: str, diff --git a/src/ai/backend/manager/api/gql_legacy/schema.py b/src/ai/backend/manager/api/gql_legacy/schema.py index 21630102060..e7fd40aab69 100644 --- a/src/ai/backend/manager/api/gql_legacy/schema.py +++ b/src/ai/backend/manager/api/gql_legacy/schema.py @@ -4,7 +4,7 @@ import logging import time import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast import attrs @@ -1757,7 +1757,7 @@ async def resolve_images( info: graphene.ResolveInfo, *, is_installed: bool | None = None, - is_operation=False, + is_operation: bool = False, filter_by_statuses: Optional[list[ImageStatus]] = None, load_filters: list[str] | None = None, image_filters: list[str] | None = None, @@ -2289,12 +2289,14 @@ async def resolve_scaling_groups_for_domain( async def resolve_scaling_groups_for_user_group( root: Any, info: graphene.ResolveInfo, - user_group, + user_group: str, is_active: Optional[bool] = None, ) -> Sequence[ScalingGroup]: + from uuid import UUID + return await ScalingGroup.load_by_group( info.context, - user_group, + UUID(user_group), is_active=is_active, ) @@ -3294,7 +3296,9 @@ async def resolve_container_utilization_metric_metadata( class GQLMutationPrivilegeCheckMiddleware: - def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + def resolve( + self, next: Callable[..., Any], root: Any, info: graphene.ResolveInfo, **args: Any + ) -> Any: graph_ctx: GraphQueryContext = info.context if info.operation.operation == 
OperationType.MUTATION: mutation_field: GraphQLField | None = getattr(Mutation, info.field_name, None) @@ -3311,7 +3315,9 @@ def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: class GQLExceptionMiddleware: - def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + def resolve( + self, next: Callable[..., Any], root: Any, info: graphene.ResolveInfo, **args: Any + ) -> Any: try: res = next(root, info, **args) except BackendAIError as e: @@ -3337,7 +3343,9 @@ def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: class GQLMetricMiddleware: - def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + def resolve( + self, next: Callable[..., Any], root: Any, info: graphene.ResolveInfo, **args: Any + ) -> Any: graph_ctx: GraphQueryContext = info.context operation_type = info.operation.operation field_name = info.field_name diff --git a/src/ai/backend/manager/api/gql_legacy/service_config.py b/src/ai/backend/manager/api/gql_legacy/service_config.py index ea7341fc7ea..04603a797d6 100644 --- a/src/ai/backend/manager/api/gql_legacy/service_config.py +++ b/src/ai/backend/manager/api/gql_legacy/service_config.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import TYPE_CHECKING, Final, Optional, Self +from typing import TYPE_CHECKING, Any, Final, Optional, Self import graphene @@ -76,7 +76,7 @@ class Meta: async def load(cls, info: graphene.ResolveInfo, service: str) -> Self: ctx: GraphQueryContext = info.context - def _fallback(x) -> str: + def _fallback(x: Any) -> str: return str(x) unified_config = ctx.config_provider.config.model_dump( @@ -182,7 +182,7 @@ def _get_etcd_prefix_key(cls, service: str) -> str: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, input: ModifyServiceConfigNodeInput, ) -> ModifyServiceConfigNodePayload: diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index eeaf5376714..9d88dc5a1d4 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -339,7 +339,7 @@ def from_row( *, permissions: Optional[Iterable[ComputeSessionPermission]] = None, ) -> Self: - status_history = row.status_history or {} + status_history: dict[str, Any] = dict(row.status_history) if row.status_history else {} raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name) result = cls( # identity @@ -451,7 +451,7 @@ def from_dataclass( return result async def __resolve_reference( - self, info: graphene.ResolveInfo, **kwargs + self, info: graphene.ResolveInfo, **kwargs: Any ) -> Optional[ComputeSessionNode]: # TODO: Confirm if scope and permission are correct # Parse the global ID from Federation (converts base64 encoded string to tuple) @@ -840,7 +840,7 @@ async def mutate_and_get_payload( cls, root: Any, info: graphene.ResolveInfo, - **input, + **input: Any, ) -> ModifyComputeSession: graph_ctx: GraphQueryContext = info.context _, raw_session_id = cast(ResolvedGlobalID, input["id"]) @@ -897,7 +897,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, input: CheckAndTransitStatusInput, ) -> CheckAndTransitStatus: diff --git a/src/ai/backend/manager/api/gql_legacy/user.py b/src/ai/backend/manager/api/gql_legacy/user.py index 05fff7995a4..ca05db3779c 100644 --- a/src/ai/backend/manager/api/gql_legacy/user.py +++ b/src/ai/backend/manager/api/gql_legacy/user.py @@ -192,7 +192,7 @@ def
from_dataclass(cls, ctx: GraphQueryContext, user_data: UserData) -> Self: ) @classmethod - async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: + async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self: graph_ctx: GraphQueryContext = info.context _, user_id = AsyncNode.resolve_global_id(info, id) @@ -464,7 +464,7 @@ async def resolve_project_nodes( result.append(GroupNode.from_row(graph_ctx, prj_row)) return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) - async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> UserNode: + async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs: Any) -> UserNode: return await UserNode.get_node(info, self.id) @@ -1074,7 +1074,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, email: str, props: UserInput, @@ -1111,7 +1111,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, email: str, props: ModifyUserInput, @@ -1150,7 +1150,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, email: str, ) -> DeleteUser: @@ -1191,7 +1191,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, email: str, props: PurgeUserInput, diff --git a/src/ai/backend/manager/api/gql_legacy/vfolder.py b/src/ai/backend/manager/api/gql_legacy/vfolder.py index 24122e8f007..26935ffc1bd 100644 --- a/src/ai/backend/manager/api/gql_legacy/vfolder.py +++ b/src/ai/backend/manager/api/gql_legacy/vfolder.py @@ -439,7 +439,9 @@ async def get_accessible_connection( ] return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) - async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> VirtualFolderNode: + async def __resolve_reference( + self, info: graphene.ResolveInfo, **kwargs: Any + ) -> VirtualFolderNode: vfolder_node = await VirtualFolderNode.get_node(info, self.id) if vfolder_node is None: raise VFolderNotFound(f"Virtual folder not found: {self.id}") @@ -1516,7 +1518,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, quota_scope_id: str, storage_host_name: str, @@ -1565,7 +1567,7 @@ class Arguments: @classmethod async def mutate( cls, - root, + root: Any, info: graphene.ResolveInfo, quota_scope_id: str, storage_host_name: str, diff --git a/src/ai/backend/manager/api/manager.py b/src/ai/backend/manager/api/manager.py index f5eb7cb04a3..228d8cabe71 100644 --- a/src/ai/backend/manager/api/manager.py +++ b/src/ai/backend/manager/api/manager.py @@ -60,9 +60,9 @@ class SchedulerOps(enum.Enum): def server_status_required( allowed_status: frozenset[ManagerStatus], ) -> Callable[[Handler], Handler]: - def decorator(handler) -> Handler: + def decorator(handler: Handler) -> Handler: @functools.wraps(handler) - async def wrapped(request, *args, **kwargs) -> web.StreamResponse: + async def wrapped(request: web.Request, *args: Any, **kwargs: Any) -> web.StreamResponse: root_ctx: RootContext = request.app["_root.context"] status = await root_ctx.config_provider.legacy_etcd_config_loader.get_manager_status() if status not in allowed_status: @@ -85,7 +85,7 @@ async def wrapped(request, *args, **kwargs) -> web.StreamResponse: class GQLMutationUnfrozenRequiredMiddleware: - def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + def resolve(self, next: Callable, root: Any, info: 
graphene.ResolveInfo, **args: Any) -> Any: graph_ctx: GraphQueryContext = info.context if ( info.operation.operation == "mutation" diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 35c298b3fe5..cacb75c4ed3 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -336,20 +336,20 @@ def check_and_return(self, value: Any) -> object: }).allow_extra("*") -def sub(d, old, new) -> dict: +def sub(d: dict[Any, Any], old: Any, new: Any) -> dict[Any, Any]: for k, v in d.items(): if isinstance(v, (Mapping, dict)): - d[k] = sub(v, old, new) + d[k] = sub(dict(v), old, new) elif d[k] == old: d[k] = new return d -def drop_undefined(d) -> dict: - newd = {} +def drop_undefined(d: dict[Any, Any]) -> dict[Any, Any]: + newd: dict[Any, Any] = {} for k, v in d.items(): if isinstance(v, (Mapping, dict)): - newval = drop_undefined(v) + newval = drop_undefined(dict(v)) if len(newval.keys()) > 0: # exclude empty dict always newd[k] = newval elif not isinstance(v, Undefined): diff --git a/src/ai/backend/manager/api/stream.py b/src/ai/backend/manager/api/stream.py index 4e6219760d2..f1a7a083e30 100644 --- a/src/ai/backend/manager/api/stream.py +++ b/src/ai/backend/manager/api/stream.py @@ -15,7 +15,7 @@ import uuid import weakref from collections import defaultdict -from collections.abc import AsyncIterator, Iterable, Mapping, MutableMapping +from collections.abc import AsyncIterator, Callable, Iterable, Mapping, MutableMapping from datetime import timedelta from typing import ( TYPE_CHECKING, @@ -69,7 +69,9 @@ @server_status_required(READ_ALLOWED) @auth_required @adefer -async def stream_pty(defer, request: web.Request) -> web.StreamResponse: +async def stream_pty( + defer: Callable[[Callable[[], None]], None], request: web.Request +) -> web.StreamResponse: root_ctx: RootContext = request.app["_root.context"] app_ctx: PrivateContext = request.app["stream.context"] database_ptask_group: aiotools.PersistentTaskGroup = request.app["database_ptask_group"] @@ -276,7 +278,9 @@ async def stream_stdout() -> None: @server_status_required(READ_ALLOWED) @auth_required @adefer -async def stream_execute(defer, request: web.Request) -> web.StreamResponse: +async def stream_execute( + defer: Callable[[Callable[[], None]], None], request: web.Request +) -> web.StreamResponse: """ WebSocket-version of gateway.kernel.execute().
""" @@ -423,7 +427,7 @@ async def stream_execute(defer, request: web.Request) -> web.StreamResponse: ) @adefer async def stream_proxy( - defer, request: web.Request, params: Mapping[str, Any] + defer: Callable[[Callable[[], None]], None], request: web.Request, params: Mapping[str, Any] ) -> web.StreamResponse: root_ctx: RootContext = request.app["_root.context"] app_ctx: PrivateContext = request.app["stream.context"] diff --git a/src/ai/backend/manager/api/utils.py b/src/ai/backend/manager/api/utils.py index c6302751aa6..21968785eb9 100644 --- a/src/ai/backend/manager/api/utils.py +++ b/src/ai/backend/manager/api/utils.py @@ -13,7 +13,15 @@ import traceback import uuid from collections import defaultdict -from collections.abc import Awaitable, Callable, Generator, Hashable, Mapping, MutableMapping +from collections.abc import ( + Awaitable, + Callable, + Generator, + Hashable, + Iterable, + Mapping, + MutableMapping, +) from typing import ( TYPE_CHECKING, Annotated, @@ -57,8 +65,8 @@ _rx_sitepkg_path = re.compile(r"^.+/site-packages/") -def method_placeholder(orig_method) -> Callable[[web.Request], Awaitable[web.Response]]: - async def _handler(request) -> web.Response: +def method_placeholder(orig_method: str) -> Callable[[web.Request], Awaitable[web.Response]]: + async def _handler(request: web.Request) -> web.Response: raise web.HTTPMethodNotAllowed(request.method, [orig_method]) return _handler @@ -374,16 +382,16 @@ def trim_text(value: str, maxlen: int) -> str: class _Infinity(numbers.Number): - def __lt__(self, o) -> bool: + def __lt__(self, o: Any) -> bool: return False - def __le__(self, o) -> bool: + def __le__(self, o: Any) -> bool: return False - def __gt__(self, o) -> bool: + def __gt__(self, o: Any) -> bool: return True - def __ge__(self, o) -> bool: + def __ge__(self, o: Any) -> bool: return False def __float__(self) -> float: @@ -400,7 +408,7 @@ def __hash__(self) -> int: Infinity = _Infinity() -def prettify_traceback(exc) -> str: +def prettify_traceback(exc: BaseException | None) -> str: # Make a compact stack trace string with io.StringIO() as buf: while exc is not None: @@ -415,10 +423,10 @@ def prettify_traceback(exc) -> str: return f"Traceback:\n{buf.getvalue()}" -def catch_unexpected(log, reraise_cancellation: bool = True, raven=None) -> Callable: - def _wrap(func) -> Callable: +def catch_unexpected(log: Any, reraise_cancellation: bool = True, raven: Any = None) -> Callable: + def _wrap(func: Callable) -> Callable: @functools.wraps(func) - async def _wrapped(*args, **kwargs) -> Any: + async def _wrapped(*args: Any, **kwargs: Any) -> Any: try: return await func(*args, **kwargs) except asyncio.CancelledError: @@ -435,7 +443,7 @@ async def _wrapped(*args, **kwargs) -> Any: return _wrap -def set_handler_attr(func, key, value) -> None: +def set_handler_attr(func: Any, key: str, value: Any) -> None: attrs = getattr(func, "_backend_attrs", None) if attrs is None: attrs = {} @@ -443,7 +451,7 @@ def set_handler_attr(func, key, value) -> None: func._backend_attrs = attrs -def get_handler_attr(request, key, default=None) -> Any: +def get_handler_attr(request: web.Request, key: str, default: Any = None) -> Any: # When used in the aiohttp server-side codes, we should use # request.match_info.handler instead of handler passed to the middleware # functions because aiohttp wraps this original handler with functools.partial @@ -465,7 +473,7 @@ async def deprecated_stub_impl(request: web.Request) -> web.Response: return deprecated_stub_impl -def chunked(iterable, n) -> Generator[tuple,
None, None]: +def chunked(iterable: Iterable, n: int) -> Generator[tuple, None, None]: it = iter(iterable) while True: chunk = tuple(itertools.islice(it, n)) @@ -532,7 +540,7 @@ async def call_non_bursty( class Singleton(type): _instances: MutableMapping[Any, Any] = {} - def __call__(cls, *args, **kwargs) -> Any: + def __call__(cls, *args: Any, **kwargs: Any) -> Any: if cls not in cls._instances: cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] diff --git a/src/ai/backend/manager/api/vfolder.py b/src/ai/backend/manager/api/vfolder.py index 42d53341172..0c4ea38f319 100644 --- a/src/ai/backend/manager/api/vfolder.py +++ b/src/ai/backend/manager/api/vfolder.py @@ -141,6 +141,7 @@ from .auth import admin_required, auth_required, superadmin_required from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required +from .types import CORSOptions from .utils import ( LegacyBaseRequestModel, LegacyBaseResponseModel, @@ -1720,8 +1721,10 @@ async def _delete( allowed_vfolder_types = ( await root_ctx.config_provider.legacy_etcd_config_loader.get_vfolder_types() ) + # Get connection from session + conn = await db_session.connection() await ensure_host_permission_allowed( - db_session.bind, + conn, folder_host, allowed_vfolder_types=allowed_vfolder_types, user_uuid=user_uuid, @@ -2816,7 +2819,7 @@ async def shutdown(app: web.Application) -> None: await app_ctx.storage_ptask_group.shutdown() -def create_app(default_cors_options) -> tuple[web.Application, list]: +def create_app(default_cors_options: CORSOptions) -> tuple[web.Application, list]: app = web.Application() app["prefix"] = "folders" app["api_versions"] = (2, 3, 4) diff --git a/src/ai/backend/manager/api/wsproxy.py b/src/ai/backend/manager/api/wsproxy.py index d28e43735c3..56581de9cd3 100644 --- a/src/ai/backend/manager/api/wsproxy.py +++ b/src/ai/backend/manager/api/wsproxy.py @@ -66,7 +66,7 @@ class TCPProxy(ServiceProxy): "down_task", ) - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.down_task: Optional[asyncio.Task] = None diff --git a/src/ai/backend/manager/cli/__main__.py b/src/ai/backend/manager/cli/__main__.py index ca20d5ad947..c6d39161b07 100644 --- a/src/ai/backend/manager/cli/__main__.py +++ b/src/ai/backend/manager/cli/__main__.py @@ -99,7 +99,9 @@ def main( ) @click.argument("psql_args", nargs=-1, type=click.UNPROCESSED) @click.pass_obj -def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args) -> None: +def dbshell( + cli_ctx: CLIContext, container_name: str | None, psql_help: bool, psql_args: tuple[str, ...] +) -> None: """ Run the database shell. 
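
The next dbshell hunk (below) rebinds the click-supplied psql_args tuple into a separately annotated local instead of reassigning the parameter in place, so each name keeps a single type under the checker. A minimal, self-contained sketch of that rebinding pattern (the function and variable names here are illustrative, not taken from the patch):

    def build_args(raw_args: tuple[str, ...], show_help: bool) -> list[str]:
        # Rebinding to a new, explicitly typed local avoids giving the
        # tuple parameter a second, incompatible type mid-function.
        _args: list[str]
        if show_help:
            _args = ["--help"]
        else:
            _args = list(raw_args)
        return _args

    assert build_args(("-c", "SELECT 1"), False) == ["-c", "SELECT 1"]
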
@@ -118,8 +120,11 @@ def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args) -> None: bootstrap_config = asyncio.run(cli_ctx.get_bootstrap_config()) db_config = bootstrap_config.db + _psql_args: list[str] if psql_help: - psql_args = ["--help"] + _psql_args = ["--help"] + else: + _psql_args = list(psql_args) if not container_name: # Try to get the database container name of the halfstack candidate_container_names = subprocess.check_output( @@ -151,7 +156,7 @@ def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args) -> None: cmd = [ "psql", (f"postgres://{db_config.user}:{db_config.password}@{db_config.addr}/{db_config.name}"), - *psql_args, + *_psql_args, ] subprocess.run(cmd) return @@ -168,7 +173,7 @@ def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args) -> None: db_config.user, "-d", db_config.name, - *psql_args, + *_psql_args, ] subprocess.run(cmd) @@ -246,7 +251,7 @@ def generate_rpc_keypair(cli_ctx: CLIContext, dst_dir: pathlib.Path, name: str) ), ) @click.pass_obj -def clear_history(cli_ctx: CLIContext, retention, vacuum_full) -> None: +def clear_history(cli_ctx: CLIContext, retention: str, vacuum_full: bool) -> None: """ Delete old records from the kernels, error_logs tables and invoke PostgreSQL's vacuum operation to clear up the actual disk space. diff --git a/src/ai/backend/manager/cli/api.py b/src/ai/backend/manager/cli/api.py index de3937359bf..e368b069ebe 100644 --- a/src/ai/backend/manager/cli/api.py +++ b/src/ai/backend/manager/cli/api.py @@ -21,7 +21,7 @@ @click.group() -def cli(args) -> None: +def cli() -> None: pass diff --git a/src/ai/backend/manager/cli/dbschema.py b/src/ai/backend/manager/cli/dbschema.py index 507ec75289b..eb729e9f4ec 100644 --- a/src/ai/backend/manager/cli/dbschema.py +++ b/src/ai/backend/manager/cli/dbschema.py @@ -32,7 +32,7 @@ class RevisionHistory(TypedDict): @click.group() -def cli(args) -> None: +def cli() -> None: pass @@ -45,7 +45,7 @@ def cli(args) -> None: help="The path to Alembic config file. [default: alembic.ini]", ) @click.pass_obj -def show(cli_ctx: CLIContext, alembic_config) -> None: +def show(cli_ctx: CLIContext, alembic_config: str) -> None: """Show the current schema information.""" from alembic.config import Config from alembic.runtime.migration import MigrationContext diff --git a/src/ai/backend/manager/cli/etcd.py b/src/ai/backend/manager/cli/etcd.py index 7b4aa9c1eaa..4757796e455 100644 --- a/src/ai/backend/manager/cli/etcd.py +++ b/src/ai/backend/manager/cli/etcd.py @@ -4,7 +4,8 @@ import json import logging import sys -from typing import TYPE_CHECKING +from decimal import Decimal +from typing import TYPE_CHECKING, BinaryIO import click @@ -47,7 +48,7 @@ def cli() -> None: help="The configuration scope to put the value.", ) @click.pass_obj -def put(cli_ctx: CLIContext, key, value, scope) -> None: +def put(cli_ctx: CLIContext, key: str, value: str, scope: ConfigScopes) -> None: """Put a single key-value pair into the etcd.""" async def _impl() -> None: @@ -71,7 +72,7 @@ async def _impl() -> None: help="The configuration scope to put the value.", ) @click.pass_obj -def put_json(cli_ctx: CLIContext, key, file, scope) -> None: +def put_json(cli_ctx: CLIContext, key: str, file: BinaryIO, scope: ConfigScopes) -> None: """ Put a JSON object from FILE to the etcd as flattened key-value pairs under the given KEY prefix.
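
Annotations on click handlers such as put and put_json above are advisory: click injects whatever its argument and option converters produce, so each annotation should simply mirror the declared converter. A small standalone example of the convention (the command and option names are illustrative only):

    import click

    @click.command()
    @click.argument("key")
    @click.option("--prefix", is_flag=True, help="Treat KEY as a prefix.")
    def show_key(key: str, prefix: bool) -> None:
        # click passes KEY as str, and is_flag=True yields a bool, so the
        # annotations match the values actually injected at runtime.
        click.echo(f"key={key} prefix={prefix}")

    if __name__ == "__main__":
        show_key()
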
@@ -103,7 +104,9 @@ async def _impl() -> None: ), ) @click.pass_obj -def move_subtree(cli_ctx: CLIContext, src_prefix, dst_prefix, scope) -> None: +def move_subtree( + cli_ctx: CLIContext, src_prefix: str, dst_prefix: str, scope: ConfigScopes +) -> None: """ Move a subtree to another key prefix. """ @@ -135,7 +138,7 @@ async def _impl() -> None: help="The configuration scope to put the value.", ) @click.pass_obj -def get(cli_ctx: CLIContext, key, prefix, scope) -> None: +def get(cli_ctx: CLIContext, key: str, prefix: bool, scope: ConfigScopes) -> None: """ Get the value of a key in the configured etcd namespace. """ @@ -168,7 +171,7 @@ async def _impl() -> None: help="The configuration scope to put the value.", ) @click.pass_obj -def delete(cli_ctx: CLIContext, key, prefix, scope) -> None: +def delete(cli_ctx: CLIContext, key: str, prefix: bool, scope: ConfigScopes) -> None: """Delete the key in the configured etcd namespace.""" async def _impl() -> None: @@ -198,7 +201,7 @@ async def _impl() -> None: @click.option("-s", "--short", is_flag=True, help="Show only the image references and digests.") @click.option("-i", "--installed", is_flag=True, help="Show only the installed images.") @click.pass_obj -def list_images(cli_ctx, short, installed) -> None: +def list_images(cli_ctx: CLIContext, short: bool, installed: bool) -> None: """List all configured images.""" log.warning("etcd list-images command is deprecated, use image list instead") asyncio.run(list_images_impl(cli_ctx, short, installed)) @@ -208,7 +211,7 @@ def list_images(cli_ctx, short, installed) -> None: @click.argument("canonical_or_alias") @click.argument("architecture") @click.pass_obj -def inspect_image(cli_ctx, canonical_or_alias, architecture) -> None: +def inspect_image(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: """Show the details of the given image or alias.""" log.warning("etcd inspect-image command is deprecated, use image inspect instead") asyncio.run(inspect_image_impl(cli_ctx, canonical_or_alias, architecture)) @@ -218,7 +221,7 @@ def inspect_image(cli_ctx, canonical_or_alias, architecture) -> None: @click.argument("canonical_or_alias") @click.argument("architecture") @click.pass_obj -def forget_image(cli_ctx, canonical_or_alias, architecture) -> None: +def forget_image(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: """Forget (delete) a specific image.""" log.warning("etcd forget-image command is deprecated, use image forget instead") asyncio.run(forget_image_impl(cli_ctx, canonical_or_alias, architecture)) @@ -231,11 +234,11 @@ def forget_image(cli_ctx, canonical_or_alias, architecture) -> None: @click.argument("architecture") @click.pass_obj def set_image_resource_limit( - cli_ctx, - canonical_or_alias, - slot_type, - range_value, - architecture, + cli_ctx: CLIContext, + canonical_or_alias: str, + slot_type: str, + range_value: tuple[Decimal | None, Decimal | None], + architecture: str, ) -> None: """Set the MIN:MAX values of a SLOT_TYPE limit for the given image REFERENCE.""" log.warning( diff --git a/src/ai/backend/manager/cli/gql.py b/src/ai/backend/manager/cli/gql.py index 8d82c77db65..55b4eca5ffb 100644 --- a/src/ai/backend/manager/cli/gql.py +++ b/src/ai/backend/manager/cli/gql.py @@ -16,7 +16,7 @@ @click.group() -def cli(args) -> None: +def cli() -> None: pass diff --git a/src/ai/backend/manager/cli/image.py b/src/ai/backend/manager/cli/image.py index fb78c8da1a5..9a0cfc66f60 100644 --- a/src/ai/backend/manager/cli/image.py +++ 
b/src/ai/backend/manager/cli/image.py @@ -2,6 +2,7 @@ import asyncio import logging +from decimal import Decimal from typing import Optional import click @@ -33,7 +34,7 @@ def cli() -> None: @click.option("-s", "--short", is_flag=True, help="Show only the image references and digests.") @click.option("-i", "--installed", is_flag=True, help="Show only the installed images.") @click.pass_obj -def list_images(cli_ctx, short, installed) -> None: +def list_images(cli_ctx: CLIContext, short: bool, installed: bool) -> None: """List all configured images.""" asyncio.run(list_images_impl(cli_ctx, short, installed)) @@ -42,7 +43,7 @@ def list_images(cli_ctx, short, installed) -> None: @click.argument("canonical_or_alias") @click.argument("architecture") @click.pass_obj -def inspect(cli_ctx, canonical_or_alias, architecture) -> None: +def inspect(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: """Show the details of the given image or alias.""" asyncio.run(inspect_image_impl(cli_ctx, canonical_or_alias, architecture)) @@ -51,7 +52,7 @@ def inspect(cli_ctx, canonical_or_alias, architecture) -> None: @click.argument("canonical_or_alias") @click.argument("architecture") @click.pass_obj -def forget(cli_ctx, canonical_or_alias, architecture) -> None: +def forget(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: """Forget (soft-delete) a specific image.""" asyncio.run(forget_image_impl(cli_ctx, canonical_or_alias, architecture)) @@ -75,11 +76,11 @@ def purge( @click.argument("architecture") @click.pass_obj def set_resource_limit( - cli_ctx, - canonical_or_alias, - slot_type, - range_value, - architecture, + cli_ctx: CLIContext, + canonical_or_alias: str, + slot_type: str, + range_value: tuple[Decimal | None, Decimal | None], + architecture: str, ) -> None: """Set the MIN:MAX values of a SLOT_TYPE limit for the given image REFERENCE.""" asyncio.run( @@ -99,7 +100,7 @@ def set_resource_limit( "-p", "--project", default=None, help="The name of the project to which the images belong." ) @click.pass_obj -def rescan(cli_ctx, registry_or_image: str, project: Optional[str] = None) -> None: +def rescan(cli_ctx: CLIContext, registry_or_image: str, project: Optional[str] = None) -> None: """ Update the kernel image metadata from the configured registries. 
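
The range_value parameters above are typed as tuple[Decimal | None, Decimal | None], encoding a MIN:MAX pair where None marks an unbounded side. A short sketch of how such a pair maps back to the MIN:MAX string form (the helper name is illustrative, not part of the patch):

    from decimal import Decimal

    def format_limit(range_value: tuple[Decimal | None, Decimal | None]) -> str:
        # None means "no bound on this side", which serializes to an
        # empty field rather than a zero.
        min_v, max_v = range_value
        lo = "" if min_v is None else str(min_v)
        hi = "" if max_v is None else str(max_v)
        return f"{lo}:{hi}"

    assert format_limit((Decimal("1"), None)) == "1:"
    assert format_limit((None, Decimal("8"))) == ":8"
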
@@ -115,7 +116,7 @@ def rescan(cli_ctx, registry_or_image: str, project: Optional[str] = None) -> No @click.argument("target") @click.argument("architecture") @click.pass_obj -def alias(cli_ctx, alias, target, architecture) -> None: +def alias(cli_ctx: CLIContext, alias: str, target: str, architecture: str) -> None: """Add an image alias from the given alias to the target image reference.""" asyncio.run(alias_impl(cli_ctx, alias, target, architecture)) @@ -123,7 +124,7 @@ def alias(cli_ctx, alias, target, architecture) -> None: @cli.command() @click.argument("alias") @click.pass_obj -def dealias(cli_ctx, alias) -> None: +def dealias(cli_ctx: CLIContext, alias: str) -> None: """Remove an alias.""" asyncio.run(dealias_impl(cli_ctx, alias)) diff --git a/src/ai/backend/manager/cli/image_impl.py b/src/ai/backend/manager/cli/image_impl.py index f273dfde451..ea0eb223865 100644 --- a/src/ai/backend/manager/cli/image_impl.py +++ b/src/ai/backend/manager/cli/image_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from decimal import Decimal from pprint import pformat, pprint from typing import Any, Optional @@ -23,7 +24,7 @@ log = BraceStyleAdapter(logging.getLogger(__spec__.name)) -async def list_images(cli_ctx: CLIContext, short, installed_only) -> None: +async def list_images(cli_ctx: CLIContext, short: bool, installed_only: bool) -> None: # Connect to postgreSQL DB bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( @@ -75,7 +76,7 @@ async def list_images(cli_ctx: CLIContext, short, installed_only) -> None: log.exception(f"An error occurred. Error: {e}") -async def inspect_image(cli_ctx: CLIContext, canonical_or_alias, architecture) -> None: +async def inspect_image(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( connect_database(bootstrap_config.db) as db, @@ -96,9 +97,10 @@ async def inspect_image(cli_ctx: CLIContext, canonical_or_alias, architecture) - log.exception(f"An error occurred. Error: {e}") -async def forget_image(cli_ctx, canonical_or_alias, architecture) -> None: +async def forget_image(cli_ctx: CLIContext, canonical_or_alias: str, architecture: str) -> None: + bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( - connect_database(cli_ctx.bootstrap_config.db) as db, + connect_database(bootstrap_config.db) as db, db.begin_session() as session, ): try: @@ -146,10 +148,10 @@ async def purge_image( async def set_image_resource_limit( cli_ctx: CLIContext, - canonical_or_alias, - slot_type, - range_value, - architecture, + canonical_or_alias: str, + slot_type: str, + range_value: tuple[Decimal | None, Decimal | None], + architecture: str, ) -> None: bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( @@ -188,7 +190,7 @@ async def rescan_images( log.exception(f"Unknown error occurred. Error: {e}") -async def alias(cli_ctx: CLIContext, alias, target, architecture) -> None: +async def alias(cli_ctx: CLIContext, alias: str, target: str, architecture: str) -> None: bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( connect_database(bootstrap_config.db) as db, @@ -208,7 +210,7 @@ async def alias(cli_ctx: CLIContext, alias, target, architecture) -> None: log.exception(f"An error occurred. 
Error: {e}") -async def dealias(cli_ctx: CLIContext, alias) -> None: +async def dealias(cli_ctx: CLIContext, alias: str) -> None: bootstrap_config = await cli_ctx.get_bootstrap_config() async with ( connect_database(bootstrap_config.db) as db, diff --git a/src/ai/backend/manager/config/loader/legacy_etcd_loader.py b/src/ai/backend/manager/config/loader/legacy_etcd_loader.py index c3cbc145904..6c5f9b1934c 100644 --- a/src/ai/backend/manager/config/loader/legacy_etcd_loader.py +++ b/src/ai/backend/manager/config/loader/legacy_etcd_loader.py @@ -97,7 +97,7 @@ async def update_resource_slots( if updates: await self._etcd.put_dict(updates) - async def update_manager_status(self, status) -> None: + async def update_manager_status(self, status: ManagerStatus) -> None: await self._etcd.put("manager/status", status.value) self.get_manager_status.cache_clear() diff --git a/src/ai/backend/manager/data/common/types.py b/src/ai/backend/manager/data/common/types.py index fc0b9ddb48a..536080b864a 100644 --- a/src/ai/backend/manager/data/common/types.py +++ b/src/ai/backend/manager/data/common/types.py @@ -19,7 +19,7 @@ class StringFilterData: i_equals: Optional[str] = None i_not_equals: Optional[str] = None - def apply_to_column(self, column) -> Optional[ColumnElement[bool]]: + def apply_to_column(self, column: ColumnElement[str]) -> Optional[ColumnElement[bool]]: """Apply this string filter to a SQLAlchemy column and return the condition. Args: @@ -63,7 +63,7 @@ class IntFilterData: less_than: Optional[int] = None less_than_or_equal: Optional[int] = None - def apply_to_column(self, column) -> Optional[ColumnElement[bool]]: + def apply_to_column(self, column: ColumnElement[int]) -> Optional[ColumnElement[bool]]: """Apply this int filter to a SQLAlchemy column and return the condition. 
Args: diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index 0aafde7d526..81568d48966 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -105,7 +105,7 @@ def ensure_all_tables_registered() -> None: # helper functions -def zero_if_none(val) -> int: +def zero_if_none(val: int | None) -> int: return 0 if val is None else val @@ -262,7 +262,7 @@ class CurvePublicKeyColumn(TypeDecorator): impl = sa.String cache_ok = True - def load_dialect_impl(self, dialect) -> TypeEngine: + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: return dialect.type_descriptor(sa.String(40)) def process_bind_param( @@ -290,7 +290,7 @@ class QuotaScopeIDType(TypeDecorator): impl = sa.String cache_ok = True - def load_dialect_impl(self, dialect) -> TypeEngine: + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: return dialect.type_descriptor(sa.String(64)) def process_bind_param( @@ -399,11 +399,19 @@ def __init__(self, schema: type[JSONSerializableMixin]) -> None: super().__init__() self._schema = schema - def process_bind_param(self, value, dialect) -> Optional[dict]: - return self._schema.to_json(value) + def process_bind_param( + self, value: JSONSerializableMixin | None, dialect: Dialect + ) -> Optional[dict]: + if value is None: + return None + return self._schema.to_json(value) # type: ignore[arg-type] - def process_result_value(self, value, dialect) -> Optional[JSONSerializableMixin]: - return self._schema.from_json(value) + def process_result_value( + self, value: dict | None, dialect: Dialect + ) -> Optional[JSONSerializableMixin]: + if value is None: + return None + return self._schema.from_json(value) # type: ignore[arg-type] def copy(self, **kw) -> Self: return StructuredJSONObjectColumn(self._schema) # type: ignore[return-value] @@ -422,13 +430,17 @@ def __init__(self, schema: type[JSONSerializableMixin]) -> None: super().__init__() self._schema = schema - def coerce_compared_value(self, op, value) -> JSONB: + def coerce_compared_value(self, op: Any, value: Any) -> JSONB: return JSONB() - def process_bind_param(self, value, dialect) -> list[dict]: - return [self._schema.to_json(item) for item in value] + def process_bind_param( + self, value: list[JSONSerializableMixin] | None, dialect: Dialect + ) -> list[dict]: + return [self._schema.to_json(item) for item in value] if value is not None else [] - def process_result_value(self, value, dialect) -> list[JSONSerializableMixin]: + def process_result_value( + self, value: list | None, dialect: Dialect + ) -> list[JSONSerializableMixin]: if value is None: return [] return [self._schema.from_json(item) for item in value] @@ -487,16 +499,16 @@ def __init__(self, schema: type[TBaseModel]) -> None: super().__init__() self._schema = schema - def coerce_compared_value(self, op, value) -> JSONB: + def coerce_compared_value(self, op: Any, value: Any) -> JSONB: return JSONB() - def process_bind_param(self, value: list[TBaseModel] | None, dialect) -> list: + def process_bind_param(self, value: list[TBaseModel] | None, dialect: Dialect) -> list: # JSONB accepts Python objects directly, not JSON strings if value is not None: return [item.model_dump(mode="json") for item in value] return [] - def process_result_value(self, value: list | str | None, dialect) -> list[TBaseModel]: + def process_result_value(self, value: list | str | None, dialect: Dialect) -> list[TBaseModel]: # JSONB returns already parsed Python objects, not strings # Handle case where value is 
stored as JSON string (legacy data) if value is not None: @@ -534,16 +546,21 @@ class IPColumn(TypeDecorator): impl = CIDR cache_ok = True - def process_bind_param(self, value, dialect) -> Optional[str]: + def process_bind_param( + self, value: str | ReadableCIDR | None, dialect: Dialect + ) -> Optional[str]: if value is None: return value try: - cidr = ReadableCIDR(value).address + if isinstance(value, str): + cidr = ReadableCIDR(value).address + else: + cidr = value.address except InvalidIpAddressValue as e: raise InvalidAPIParameters(f"{value} is invalid IP address value") from e return cidr - def process_result_value(self, value, dialect) -> Optional[ReadableCIDR]: + def process_result_value(self, value: str | None, dialect: Dialect) -> Optional[ReadableCIDR]: if value is None: return None return ReadableCIDR(value) @@ -645,12 +662,12 @@ class GUID[TUUIDSubType: uuid.UUID](TypeDecorator): uuid_subtype_func: ClassVar[Callable[[Any], uuid.UUID]] = lambda v: v cache_ok = True - def load_dialect_impl(self, dialect) -> TypeEngine: + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) return dialect.type_descriptor(CHAR(16)) - def process_bind_param(self, value: Any | None, dialect) -> str | bytes | None: + def process_bind_param(self, value: Any | None, dialect: Dialect) -> str | bytes | None: # NOTE: EndpointId, SessionId, KernelId are *not* actual types defined as classes, # but a "virtual" type that is an identity function at runtime. # The type checker treats them as distinct derivatives of uuid.UUID. @@ -665,7 +682,7 @@ def process_bind_param(self, value: Any | None, dialect) -> str | bytes | None: return value.bytes return uuid.UUID(value).bytes - def process_result_value(self, value: Any, dialect) -> Optional[TUUIDSubType]: + def process_result_value(self, value: Any, dialect: Dialect) -> Optional[TUUIDSubType]: if value is None: return value cls = type(self) @@ -702,10 +719,10 @@ def __init__( allow_unicode=allow_unicode, ) - def coerce_compared_value(self, op, value) -> Unicode: + def coerce_compared_value(self, op: Any, value: Any) -> Unicode: return Unicode() - def process_bind_param(self, value: Any | None, dialect) -> str | None: + def process_bind_param(self, value: Any | None, dialect: Dialect) -> str | None: if value is None: return value try: @@ -730,29 +747,29 @@ class KernelIDColumnType(GUID[KernelId]): cache_ok = True -def IDColumn(name="id") -> sa.Column: +def IDColumn(name: str = "id") -> sa.Column: return sa.Column(name, GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) -def EndpointIDColumn(name="id") -> sa.Column: +def EndpointIDColumn(name: str = "id") -> sa.Column: return sa.Column( name, EndpointIDColumnType, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) -def SessionIDColumn(name="id") -> sa.Column: +def SessionIDColumn(name: str = "id") -> sa.Column: return sa.Column( name, SessionIDColumnType, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) -def KernelIDColumn(name="id") -> sa.Column: +def KernelIDColumn(name: str = "id") -> sa.Column: return sa.Column( name, KernelIDColumnType, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) -def ForeignKeyIDColumn(name, fk_field, nullable=True) -> sa.Column: +def ForeignKeyIDColumn(name: str, fk_field: str, nullable: bool = True) -> sa.Column: return sa.Column(name, GUID, sa.ForeignKey(fk_field), nullable=nullable) diff --git a/src/ai/backend/manager/models/image/row.py 
b/src/ai/backend/manager/models/image/row.py index e7166facace..5c6405b9cc0 100644 --- a/src/ai/backend/manager/models/image/row.py +++ b/src/ai/backend/manager/models/image/row.py @@ -392,36 +392,36 @@ class ImageRow(Base): def __init__( self, - name, - project, - architecture, - registry_id, - is_local=False, - registry=None, - image=None, - tag=None, - config_digest=None, - size_bytes=None, - type=None, - accelerators=None, - labels=None, - resources=None, - status=ImageStatus.ALIVE, + name: str, + project: str | None, + architecture: str, + registry_id: UUID, + is_local: bool = False, + registry: str | None = None, + image: str | None = None, + tag: str | None = None, + config_digest: str | None = None, + size_bytes: int | None = None, + type: ImageType | None = None, + accelerators: str | None = None, + labels: dict[str, Any] | None = None, + resources: dict[str, Any] | None = None, + status: ImageStatus = ImageStatus.ALIVE, ) -> None: self.name = name self.project = project - self.registry = registry + self.registry = registry # type: ignore[assignment] self.registry_id = registry_id - self.image = image + self.image = image # type: ignore[assignment] self.tag = tag self.architecture = architecture self.is_local = is_local - self.config_digest = config_digest - self.size_bytes = size_bytes - self.type = type + self.config_digest = config_digest # type: ignore[assignment] + self.size_bytes = size_bytes # type: ignore[assignment] + self.type = type # type: ignore[assignment] self.accelerators = accelerators - self.labels = labels - self._resources = resources + self.labels = labels # type: ignore[assignment] + self._resources = resources # type: ignore[assignment] self.status = status @property @@ -563,7 +563,7 @@ def from_dataclass(cls, image_data: ImageData) -> Self: type=image_data.type, accelerators=image_data.accelerators, labels=image_data.labels.label_data, - resources=image_data.resources.resources_data, + resources={str(k): v for k, v in image_data.resources.resources_data.items()}, status=image_data.status, ) image_row.id = image_data.id diff --git a/src/ai/backend/manager/models/kernel/row.py b/src/ai/backend/manager/models/kernel/row.py index 55f665cd058..7e3d446042e 100644 --- a/src/ai/backend/manager/models/kernel/row.py +++ b/src/ai/backend/manager/models/kernel/row.py @@ -3,7 +3,7 @@ import asyncio import logging import uuid -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager as actxmgr from datetime import datetime, tzinfo from typing import ( @@ -179,7 +179,7 @@ async def get_user_email( return user_email.replace("@", "_") -def default_hostname(context) -> str: +def default_hostname(context: Any) -> str: params = context.get_current_parameters() return f"{params['cluster_role']}{params['cluster_idx']}" @@ -283,8 +283,8 @@ async def handle_kernel_exception( db: ExtendedAsyncSAEngine, op: str, kernel_id: KernelId, - error_callback=None, - cancellation_callback=None, + error_callback: Callable[[], Any] | None = None, + cancellation_callback: Callable[[], Any] | None = None, set_error: bool = False, ) -> AsyncIterator[None]: exc_class = OP_EXC[op] diff --git a/src/ai/backend/manager/models/minilang/__init__.py b/src/ai/backend/manager/models/minilang/__init__.py index c2d7a218963..3bd0984eeab 100644 --- a/src/ai/backend/manager/models/minilang/__init__.py +++ b/src/ai/backend/manager/models/minilang/__init__.py @@ -34,12 +34,13 
@@ class EnumFieldItem[TEnum: Enum](NamedTuple): ] -def get_col_from_table(table, column_name: str) -> sa.Column: - try: +def get_col_from_table( + table: sa.Table | sa.sql.Join | type, column_name: str +) -> sa.Column | sa.orm.attributes.InstrumentedAttribute | sa.sql.elements.KeyedColumnElement: + if isinstance(table, (sa.Table, sa.sql.Join)): return table.c[column_name] - except AttributeError: - # For ORM class table - return getattr(table, column_name) + # For ORM class table + return getattr(table, column_name) class ExternalTableFilterSpec: diff --git a/src/ai/backend/manager/models/minilang/ordering.py b/src/ai/backend/manager/models/minilang/ordering.py index 6a3d140578c..944c4db5faf 100644 --- a/src/ai/backend/manager/models/minilang/ordering.py +++ b/src/ai/backend/manager/models/minilang/ordering.py @@ -37,17 +37,21 @@ class OrderDirection(enum.Enum): class OrderingItem(NamedTuple): - column: sa.Column | sa.orm.attributes.InstrumentedAttribute + column: sa.Column | sa.orm.attributes.InstrumentedAttribute | sa.sql.elements.KeyedColumnElement order_direction: OrderDirection class QueryOrderTransformer(Transformer): - def __init__(self, sa_table: sa.Table, column_map: Optional[ColumnMapType] = None) -> None: + def __init__( + self, sa_table: sa.Table | sa.sql.Join | type, column_map: Optional[ColumnMapType] = None + ) -> None: super().__init__() self._sa_table = sa_table self._column_map = column_map - def _get_col(self, col_name: str) -> sa.Column: + def _get_col( + self, col_name: str + ) -> sa.Column | sa.orm.attributes.InstrumentedAttribute | sa.sql.elements.KeyedColumnElement: try: if self._column_map: col_value, func = self._column_map[col_name] @@ -59,10 +63,10 @@ def _get_col(self, col_name: str) -> sa.Column: matched_col = _column.op("->>")(_key) # type: ignore[assignment] case _: raise ValueError("Invalid type of field name", col_name) - col = func(matched_col) if func is not None else matched_col + col = func(matched_col) if func is not None else matched_col # type: ignore[arg-type] else: col = get_col_from_table(self._sa_table, col_name) - return col + return col # type: ignore[return-value] except KeyError as e: raise ValueError("Unknown/unsupported field name", col_name) from e @@ -88,7 +92,9 @@ def __init__(self, column_map: Optional[ColumnMapType] = None) -> None: self._column_map = column_map self._parser = _parser - def parse_order(self, table, order_expr: str) -> list[OrderingItem]: + def parse_order( + self, table: sa.Table | sa.sql.Join | type, order_expr: str + ) -> list[OrderingItem]: try: ast = self._parser.parse(order_expr) return QueryOrderTransformer(table, self._column_map).transform(ast) @@ -105,8 +111,12 @@ def append_ordering( the given SQLAlchemy query object. 
""" table = sa_query.froms[0] + # FromClause is compatible with our union type, cast for type checker + from typing import cast + + parsed_table = cast(sa.Table | sa.sql.Join | type, table) orders = [ col.asc() if direction == OrderDirection.ASC else col.desc() - for col, direction in self.parse_order(table, order_expr) + for col, direction in self.parse_order(parsed_table, order_expr) ] return sa_query.order_by(*orders) diff --git a/src/ai/backend/manager/models/minilang/queryfilter.py b/src/ai/backend/manager/models/minilang/queryfilter.py index 9a058e78038..be3b209783a 100644 --- a/src/ai/backend/manager/models/minilang/queryfilter.py +++ b/src/ai/backend/manager/models/minilang/queryfilter.py @@ -63,7 +63,9 @@ class QueryFilterTransformer(Transformer): - def __init__(self, sa_table: sa.Table, fieldspec: Optional[FieldSpecType] = None) -> None: + def __init__( + self, sa_table: sa.Table | sa.sql.Join | type, fieldspec: Optional[FieldSpecType] = None + ) -> None: super().__init__() self._sa_table = sa_table self._fieldspec = fieldspec @@ -220,7 +222,7 @@ def __init__(self, fieldspec: Optional[FieldSpecType] = None) -> None: def parse_filter( self, - table, + table: sa.Table | sa.sql.Join | type, filter_expr: str, ) -> WhereClauseType: try: @@ -239,13 +241,17 @@ def append_filter( Parse the given filter expression and build the where clause based on the first target table from the given SQLAlchemy query object. """ + from typing import cast + if isinstance(sa_query, sa.sql.Select): table = sa_query.froms[0] elif isinstance(sa_query, (sa.sql.Delete, sa.sql.Update)): table = sa_query.table else: raise ValueError("Unsupported SQLAlchemy query object type") - where_clause = self.parse_filter(table, filter_expr) + # FromClause is compatible with our union type, cast for type checker + parsed_table = cast(sa.Table | sa.sql.Join | type, table) + where_clause = self.parse_filter(parsed_table, filter_expr) final_query = sa_query.where(where_clause) if final_query is None: raise DataTransformationFailed("Failed to apply filter to query") diff --git a/src/ai/backend/manager/models/network/row.py b/src/ai/backend/manager/models/network/row.py index cfbd11e8f41..44facf674aa 100644 --- a/src/ai/backend/manager/models/network/row.py +++ b/src/ai/backend/manager/models/network/row.py @@ -121,8 +121,8 @@ async def get( cls, session: AsyncSession, network_id: uuid.UUID, - load_project=False, - load_domain=False, + load_project: bool = False, + load_domain: bool = False, ) -> NetworkRow: query = sa.select(NetworkRow).filter(NetworkRow.id == network_id) if load_project: diff --git a/src/ai/backend/manager/models/rbac_models/migration/utils.py b/src/ai/backend/manager/models/rbac_models/migration/utils.py index 50729fa1c94..05f218004b1 100644 --- a/src/ai/backend/manager/models/rbac_models/migration/utils.py +++ b/src/ai/backend/manager/models/rbac_models/migration/utils.py @@ -21,13 +21,15 @@ ) -def insert_if_data_exists(db_conn: Connection, row_type, data: Collection[dict[str, Any]]) -> None: +def insert_if_data_exists( + db_conn: Connection, row_type: sa.Table, data: Collection[dict[str, Any]] +) -> None: if data: db_conn.execute(sa.insert(row_type), list(data)) def insert_skip_on_conflict( - db_conn: Connection, row_type, data: Collection[dict[str, Any]] + db_conn: Connection, row_type: sa.Table, data: Collection[dict[str, Any]] ) -> None: if data: stmt = pg_insert(row_type).values(list(data)).on_conflict_do_nothing() @@ -36,7 +38,7 @@ def insert_skip_on_conflict( def insert_and_returning_id( db_conn: 
Connection, - row_type, + row_type: sa.Table, data: Any, ) -> uuid.UUID: stmt = sa.insert(row_type).values(data).returning(row_type.c.id) diff --git a/src/ai/backend/manager/models/session/row.py b/src/ai/backend/manager/models/session/row.py index f7647c5a057..35b7725aa9d 100644 --- a/src/ai/backend/manager/models/session/row.py +++ b/src/ai/backend/manager/models/session/row.py @@ -3,7 +3,7 @@ import asyncio import enum import logging -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager as actxmgr from dataclasses import dataclass, field from datetime import datetime @@ -446,8 +446,8 @@ async def handle_session_exception( db: ExtendedAsyncSAEngine, op: str, session_id: SessionId, - error_callback=None, - cancellation_callback=None, + error_callback: Callable[[], Any] | None = None, + cancellation_callback: Callable[[], Any] | None = None, set_error: bool = False, ) -> AsyncIterator[None]: exc_class = OP_EXC[op] @@ -511,14 +511,14 @@ async def handle_session_exception( def _build_session_fetch_query( - base_cond, + base_cond: Any, access_key: AccessKey | None = None, *, allow_stale: bool = True, for_update: bool = False, do_ordering: bool = False, max_matches: Optional[int] = None, - eager_loading_op: Optional[Sequence] = None, + eager_loading_op: Optional[Sequence[_AbstractLoad]] = None, ) -> sa.sql.Select: cond = base_cond if access_key: @@ -1508,7 +1508,7 @@ async def list_sessions( *, allow_stale: bool = False, for_update: bool = False, - kernel_loading_strategy=KernelLoadingStrategy.NONE, + kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE, eager_loading_op: list[Any] | None = None, max_load_count: Optional[int] = None, ) -> Iterable[SessionRow]: @@ -1557,7 +1557,7 @@ async def get_session_by_id( max_matches: int | None = None, allow_stale: bool = True, for_update: bool = False, - eager_loading_op=None, + eager_loading_op: Sequence[_AbstractLoad] | None = None, ) -> SessionRow: sessions = await _match_sessions_by_id( db_session, diff --git a/src/ai/backend/manager/models/session_template.py b/src/ai/backend/manager/models/session_template.py index 2dbebb29501..7405ac01d9a 100644 --- a/src/ai/backend/manager/models/session_template.py +++ b/src/ai/backend/manager/models/session_template.py @@ -149,7 +149,7 @@ async def query_accessible_session_templates( user_role: Optional[UserRole] = None, domain_name: Optional[str] = None, allowed_types: Iterable[str] = ["user"], - extra_conds=None, + extra_conds: Any = None, ) -> list[Mapping[str, Any]]: from .group import association_groups_users as agus from .group import groups diff --git a/src/ai/backend/manager/models/vfolder/row.py b/src/ai/backend/manager/models/vfolder/row.py index 42cedc04e73..bc3cf64ce22 100644 --- a/src/ai/backend/manager/models/vfolder/row.py +++ b/src/ai/backend/manager/models/vfolder/row.py @@ -552,14 +552,14 @@ async def query_accessible_vfolders( user_uuid: uuid.UUID, *, # when enabled, skip vfolder ownership check if user role is admin or superadmin - allow_privileged_access=False, - user_role=None, - domain_name=None, - allowed_vfolder_types=None, - extra_vf_conds=None, - extra_invited_vf_conds=None, - extra_vf_user_conds=None, - extra_vf_group_conds=None, + allow_privileged_access: bool = False, + user_role: UserRole | str | None = None, + domain_name: str | None = None, + allowed_vfolder_types: Sequence[str] | None = None, + extra_vf_conds: Any 
= None, + extra_invited_vf_conds: Any = None, + extra_vf_user_conds: Any = None, + extra_vf_group_conds: Any = None, allowed_status_set: VFolderStatusSet | None = None, ) -> Sequence[Mapping[str, Any]]: from ai.backend.manager.models.group import association_groups_users as agus @@ -592,7 +592,7 @@ async def query_accessible_vfolders( # users.c.email, ] - async def _append_entries(_query, _is_owner=True) -> None: + async def _append_entries(_query: sa.sql.Select, _is_owner: bool = True) -> None: if extra_vf_conds is not None: _query = _query.where(extra_vf_conds) if extra_vf_user_conds is not None: @@ -746,7 +746,7 @@ async def _append_entries(_query, _is_owner=True) -> None: async def get_allowed_vfolder_hosts_by_group( conn: SAConnection, - resource_policy, + resource_policy: Mapping[str, Any], domain_name: str, group_id: Optional[uuid.UUID] = None, ) -> VFolderHostPermissionMap: @@ -1193,7 +1193,7 @@ async def _update() -> None: async def ensure_host_permission_allowed( - db_conn, + db_conn: SAConnection, folder_host: str, *, permission: VFolderHostPermission, @@ -1222,7 +1222,7 @@ async def ensure_host_permission_allowed( async def filter_host_allowed_permission( - db_conn, + db_conn: SAConnection, *, allowed_vfolder_types: Sequence[str], user_uuid: uuid.UUID, diff --git a/src/ai/backend/manager/openapi.py b/src/ai/backend/manager/openapi.py index 770dfb4f231..bbf4e11b4c3 100644 --- a/src/ai/backend/manager/openapi.py +++ b/src/ai/backend/manager/openapi.py @@ -257,7 +257,7 @@ def extract_api_handler_params( return body_model, query_model, path_model -def generate_openapi(subapps: list[web.Application], verbose=False) -> dict[str, Any]: +def generate_openapi(subapps: list[web.Application], verbose: bool = False) -> dict[str, Any]: openapi: dict[str, Any] = { "openapi": "3.1.0", "info": { diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index d553c55d763..34d753bcf77 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -297,7 +297,7 @@ async def enumerate_instances(self, check_shadow: bool = True) -> list[AgentId]: result = await db_sess.execute(query) return [AgentId(row) for row in result.scalars().all()] - async def update_instance(self, inst_id, updated_fields) -> None: + async def update_instance(self, inst_id: AgentId, updated_fields: dict[str, Any]) -> None: async def _update() -> None: async with self.db.begin() as conn: query = sa.update(agents).values(**updated_fields).where(agents.c.id == inst_id) @@ -368,10 +368,10 @@ async def create_session( config: dict[str, Any], cluster_mode: ClusterMode, cluster_size: int, - dry_run=False, - reuse=False, - enqueue_only=False, - max_wait_seconds=0, + dry_run: bool = False, + reuse: bool = False, + enqueue_only: bool = False, + max_wait_seconds: int = 0, priority: int = SESSION_PRIORITY_DEFAULT, bootstrap_script: Optional[str] = None, dependencies: Optional[list[uuid.UUID]] = None, @@ -668,9 +668,9 @@ async def create_cluster( scaling_group: str, sess_type: SessionTypes, tag: str, - enqueue_only=False, - max_wait_seconds=0, - sudo_session_enabled=False, + enqueue_only: bool = False, + max_wait_seconds: int = 0, + sudo_session_enabled: bool = False, attach_network: uuid.UUID | None = None, ) -> Mapping[str, Any]: resp: MutableMapping[str, Any] = {} @@ -1060,7 +1060,9 @@ async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair: "public_key": public_key.decode("utf-8"), } - async def get_user_occupancy(self, user_id, *, db_sess=None) -> ResourceSlot: + async 
def get_user_occupancy( + self, user_id: uuid.UUID, *, db_sess: Optional[AsyncSession] = None + ) -> ResourceSlot: known_slot_types = await self.config_provider.legacy_etcd_config_loader.get_resource_slots() async def _query() -> ResourceSlot: @@ -1081,7 +1083,9 @@ async def _query() -> ResourceSlot: return await execute_with_retry(_query) - async def get_keypair_occupancy(self, access_key, *, db_sess=None) -> ResourceSlot: + async def get_keypair_occupancy( + self, access_key: AccessKey, *, db_sess: Optional[AsyncSession] = None + ) -> ResourceSlot: known_slot_types = await self.config_provider.legacy_etcd_config_loader.get_resource_slots() async def _query() -> ResourceSlot: @@ -1107,7 +1111,9 @@ async def _query() -> ResourceSlot: return await execute_with_retry(_query) - async def get_domain_occupancy(self, domain_name, *, db_sess=None) -> ResourceSlot: + async def get_domain_occupancy( + self, domain_name: str, *, db_sess: Optional[AsyncSession] = None + ) -> ResourceSlot: # TODO: store domain occupied_slots in Redis? known_slot_types = await self.config_provider.legacy_etcd_config_loader.get_resource_slots() @@ -1135,7 +1141,9 @@ async def _query() -> ResourceSlot: return await execute_with_retry(_query) - async def get_group_occupancy(self, group_id, *, db_sess=None) -> ResourceSlot: + async def get_group_occupancy( + self, group_id: uuid.UUID, *, db_sess: Optional[AsyncSession] = None + ) -> ResourceSlot: # TODO: store domain occupied_slots in Redis? known_slot_types = await self.config_provider.legacy_etcd_config_loader.get_resource_slots() diff --git a/src/ai/backend/manager/repositories/artifact/db_source/db_source.py b/src/ai/backend/manager/repositories/artifact/db_source/db_source.py index 820450263b9..9c26e0fc861 100644 --- a/src/ai/backend/manager/repositories/artifact/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/artifact/db_source/db_source.py @@ -73,13 +73,17 @@ def apply_entity_filters( # Handle StringFilter-based filters if filters.name_filter is not None: - name_condition = filters.name_filter.apply_to_column(ArtifactRow.name) + name_condition = filters.name_filter.apply_to_column( + cast(sa.sql.elements.ColumnElement[str], ArtifactRow.name) + ) if name_condition is not None: conditions.append(name_condition) # Handle registry_filter by joining with registry tables if filters.registry_filter is not None: - registry_condition = filters.registry_filter.apply_to_column(ArtifactRegistryRow.name) + registry_condition = filters.registry_filter.apply_to_column( + cast(sa.sql.elements.ColumnElement[str], ArtifactRegistryRow.name) + ) if registry_condition is not None: # Join with artifact registry table and add condition stmt = stmt.join( @@ -91,7 +95,9 @@ def apply_entity_filters( # Handle source_filter by joining with source registry tables if filters.source_filter is not None: source_registry = sa.orm.aliased(ArtifactRegistryRow) - source_condition = filters.source_filter.apply_to_column(source_registry.name) + source_condition = filters.source_filter.apply_to_column( + cast(sa.sql.elements.ColumnElement[str], source_registry.name) + ) if source_condition is not None: # Join with source registry table (using alias to avoid conflicts) stmt = stmt.join( @@ -164,14 +170,18 @@ def apply_entity_filters( # Handle StringFilter-based version filter if filters.version_filter is not None: version_filter = filters.version_filter.to_dataclass() - version_condition = version_filter.apply_to_column(ArtifactRevisionRow.version) + version_condition = 
version_filter.apply_to_column( + cast(sa.sql.elements.ColumnElement[str], ArtifactRevisionRow.version) + ) if version_condition is not None: conditions.append(version_condition) # Handle IntFilter-based size filter if filters.size_filter is not None: size_filter = filters.size_filter.to_dataclass() - size_condition = size_filter.apply_to_column(ArtifactRevisionRow.size) + size_condition = size_filter.apply_to_column( + cast(sa.sql.elements.ColumnElement[int], ArtifactRevisionRow.size) + ) if size_condition is not None: conditions.append(size_condition) diff --git a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py index be9daa0f76d..7055cb42d5f 100644 --- a/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/scheduler/db_source/db_source.py @@ -7,9 +7,12 @@ from collections.abc import AsyncIterator, Mapping from contextlib import asynccontextmanager as actxmgr from datetime import datetime -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import UUID +if TYPE_CHECKING: + from ai.backend.manager.clients.storage_proxy.session_manager import StorageSessionManager + import sqlalchemy as sa from sqlalchemy.engine import CursorResult from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection @@ -136,6 +139,7 @@ SessionWithKernels, UserResourcePolicy, ) +from ai.backend.manager.types import UserScope from .types import KeypairConcurrencyData, SessionRowCache @@ -1171,7 +1175,7 @@ async def fetch_session_creation_data( self, spec: SessionCreationSpec, scaling_group_name: str, - storage_manager, + storage_manager: StorageSessionManager, allowed_vfolder_types: list[str], ) -> SessionCreationContext: """ @@ -1362,9 +1366,9 @@ async def query_allowed_scaling_groups( async def _fetch_vfolder_mounts( self, db_sess: SASession, - storage_manager, + storage_manager: StorageSessionManager, allowed_vfolder_types: list[str], - user_scope, + user_scope: UserScope, resource_policy: dict[str, Any], combined_mounts: list[str], combined_mount_map: dict[str | UUID, str], @@ -1391,7 +1395,7 @@ async def _fetch_vfolder_mounts( async def _fetch_dotfiles( self, db_sess: SASession, - user_scope, + user_scope: UserScope, access_key: AccessKey, vfolder_mounts: list, ) -> dict[str, Any]: @@ -1430,9 +1434,9 @@ async def _fetch_user_container_info( async def prepare_vfolder_mounts( self, - storage_manager, + storage_manager: StorageSessionManager, allowed_vfolder_types: list[str], - user_scope, + user_scope: UserScope, resource_policy: dict[str, Any], combined_mounts: list[str], combined_mount_map: dict[str | UUID, str], @@ -1456,7 +1460,7 @@ async def prepare_vfolder_mounts( async def prepare_dotfiles( self, - user_scope, + user_scope: UserScope, access_key: AccessKey, vfolder_mounts: list, ) -> dict[str, Any]: diff --git a/src/ai/backend/manager/repositories/scheduler/repository.py b/src/ai/backend/manager/repositories/scheduler/repository.py index 30324728422..a276454e253 100644 --- a/src/ai/backend/manager/repositories/scheduler/repository.py +++ b/src/ai/backend/manager/repositories/scheduler/repository.py @@ -5,9 +5,12 @@ import logging from collections.abc import Mapping from datetime import datetime -from typing import Optional +from typing import TYPE_CHECKING, Optional from uuid import UUID +if TYPE_CHECKING: + from ai.backend.manager.clients.storage_proxy.session_manager import StorageSessionManager + 
from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient from ai.backend.common.exception import BackendAIError from ai.backend.common.metrics.metric import DomainType, LayerType @@ -43,6 +46,7 @@ SessionsForStartWithImages, SessionWithKernels, ) +from ai.backend.manager.types import UserScope from .cache_source.cache_source import ScheduleCacheSource from .db_source.db_source import ScheduleDBSource @@ -230,7 +234,7 @@ async def fetch_session_creation_data( self, spec: SessionCreationSpec, scaling_group_name: str, - storage_manager, + storage_manager: StorageSessionManager, allowed_vfolder_types: list[str], ) -> SessionCreationContext: """ @@ -291,9 +295,9 @@ async def query_allowed_scaling_groups( @scheduler_repository_resilience.apply() async def prepare_vfolder_mounts( self, - storage_manager, + storage_manager: StorageSessionManager, allowed_vfolder_types: list[str], - user_scope, + user_scope: UserScope, resource_policy: dict, combined_mounts: list, combined_mount_map: dict, @@ -315,7 +319,7 @@ async def prepare_vfolder_mounts( @scheduler_repository_resilience.apply() async def prepare_dotfiles( self, - user_scope, + user_scope: UserScope, access_key: AccessKey, vfolder_mounts: list[VFolderMount], ) -> dict: diff --git a/src/ai/backend/manager/repositories/session/repository.py b/src/ai/backend/manager/repositories/session/repository.py index 9553a8b0687..63f3ca9699d 100644 --- a/src/ai/backend/manager/repositories/session/repository.py +++ b/src/ai/backend/manager/repositories/session/repository.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import uuid -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast + +if TYPE_CHECKING: + from ai.backend.common.bgtask.reporter import ProgressReporter import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession @@ -369,7 +374,7 @@ async def rescan_images( self, image_canonical: str, registry_project: str, - reporter=None, + reporter: Optional[ProgressReporter] = None, ) -> RescanImagesResult: return await rescan_images( self._db, diff --git a/src/ai/backend/manager/repositories/types.py b/src/ai/backend/manager/repositories/types.py index 30d66000c27..5cfe473fd8c 100644 --- a/src/ai/backend/manager/repositories/types.py +++ b/src/ai/backend/manager/repositories/types.py @@ -121,7 +121,7 @@ def build_lexicographic_cursor_conditions( # Cache subqueries to avoid duplication subquery_cache = {} - def get_cursor_value_subquery(column) -> sa.ScalarSelect: + def get_cursor_value_subquery(column: sa.Column[Any]) -> sa.ScalarSelect[Any]: """Get or create cached subquery for cursor value""" if column not in subquery_cache: id_column = self.model_class.id diff --git a/src/ai/backend/manager/repositories/user/repository.py b/src/ai/backend/manager/repositories/user/repository.py index 102ce72ed56..45b2abdb6f8 100644 --- a/src/ai/backend/manager/repositories/user/repository.py +++ b/src/ai/backend/manager/repositories/user/repository.py @@ -369,7 +369,7 @@ async def _get_user_by_uuid(self, session: SASession, user_uuid: UUID) -> UserRo raise UserNotFound(f"User with UUID {user_uuid} not found.") return res - async def _get_user_by_email_with_conn(self, conn, email: str) -> UserRow: + async def _get_user_by_email_with_conn(self, conn: AsyncConnection, email: str) -> UserRow: """Private method to get user by email using connection.""" result = await conn.execute(sa.select(users).where(users.c.email == email)) res = result.first() @@ -414,7 +414,7 @@ async def 
_add_user_to_groups( log.info("Adding new user {0} with no groups in domain {1}", user_uuid, domain_name) async def _validate_and_update_main_access_key( - self, conn, email: str, main_access_key: str + self, conn: AsyncConnection, email: str, main_access_key: str ) -> None: """Private method to validate and update main access key.""" session = SASession(conn) @@ -436,7 +436,9 @@ async def _validate_and_update_main_access_key( sa.update(users).where(users.c.email == email).values(main_access_key=main_access_key) ) - async def _sync_keypair_roles(self, conn, user_uuid: UUID, new_role: UserRole) -> None: + async def _sync_keypair_roles( + self, conn: AsyncConnection, user_uuid: UUID, new_role: UserRole + ) -> None: """Private method to sync keypair roles with user role.""" result = await conn.execute( sa.select( @@ -454,6 +456,8 @@ async def _sync_keypair_roles(self, conn, user_uuid: UUID, new_role: UserRole) - if new_role in [UserRole.SUPERADMIN, UserRole.ADMIN]: # User becomes admin - set first keypair as active admin kp = result.first() + if kp is None: + return kp_data = {} if not kp.is_admin: kp_data["is_admin"] = True @@ -492,7 +496,7 @@ async def _sync_keypair_roles(self, conn, user_uuid: UUID, new_role: UserRole) - kp_updates, ) - async def _clear_user_groups(self, conn, user_uuid: UUID) -> None: + async def _clear_user_groups(self, conn: AsyncConnection, user_uuid: UUID) -> None: """Private method to clear user's group associations.""" await conn.execute( sa.delete(association_groups_users).where( @@ -501,7 +505,7 @@ async def _clear_user_groups(self, conn, user_uuid: UUID) -> None: ) async def _update_user_groups( - self, conn, user_uuid: UUID, domain_name: str, group_ids: list[str] + self, conn: AsyncConnection, user_uuid: UUID, domain_name: str, group_ids: list[str] ) -> None: """Private method to update user's group associations.""" # Clear existing groups diff --git a/src/ai/backend/manager/repositories/vfolder/repository.py b/src/ai/backend/manager/repositories/vfolder/repository.py index dc4fcab8047..c949b51c9bd 100644 --- a/src/ai/backend/manager/repositories/vfolder/repository.py +++ b/src/ai/backend/manager/repositories/vfolder/repository.py @@ -1030,8 +1030,10 @@ async def ensure_host_permission_allowed( Ensure that the user has the required permission on the specified vfolder host. 
""" async with self._db.begin_session() as session: + # Get connection from session + conn = await session.connection() await ensure_host_permission_allowed( - session.bind, + conn, folder_host, permission=permission, allowed_vfolder_types=allowed_vfolder_types, diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py index 6c0582daccd..01c2a389dbb 100644 --- a/src/ai/backend/manager/scheduler/predicates.py +++ b/src/ai/backend/manager/scheduler/predicates.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession as SASession from sqlalchemy.orm import load_only, noload -from ai.backend.common.types import ResourceSlot, SessionResult, SessionTypes +from ai.backend.common.types import AccessKey, ResourceSlot, SessionResult, SessionTypes from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.session.types import SessionStatus from ai.backend.manager.models.domain import DomainRow @@ -151,8 +151,11 @@ async def check_keypair_resource_limit( total_keypair_allowed = ResourceSlot.from_policy( resource_policy_map, sched_ctx.known_slot_types ) + + if sess_ctx.access_key is None: + return PredicateResult(False, "Session has no access key") key_occupied = await sched_ctx.registry.get_keypair_occupancy( - sess_ctx.access_key, db_sess=db_sess + AccessKey(sess_ctx.access_key), db_sess=db_sess ) log.debug("keypair:{} current-occupancy: {}", sess_ctx.access_key, key_occupied) log.debug("keypair:{} total-allowed: {}", sess_ctx.access_key, total_keypair_allowed) diff --git a/src/ai/backend/manager/sokovan/deployment/types.py b/src/ai/backend/manager/sokovan/deployment/types.py index f124e70853e..2c66e4ea14e 100644 --- a/src/ai/backend/manager/sokovan/deployment/types.py +++ b/src/ai/backend/manager/sokovan/deployment/types.py @@ -63,14 +63,16 @@ class RouteCreationSpec: # Extension methods for DeploymentInfo compatibility @staticmethod - def get_target_replicas_from_deployment(deployment_info) -> int: + def get_target_replicas_from_deployment(deployment_info: DeploymentInfo) -> int: """Get the target number of replicas for a DeploymentInfo.""" # DeploymentInfo has replica_spec.replica_count return deployment_info.replica_spec.replica_count # Extension methods for DeploymentInfoWithRoutes compatibility @staticmethod - def get_healthy_route_count_from_deployment(deployment_with_routes) -> int: + def get_healthy_route_count_from_deployment( + deployment_with_routes: DeploymentInfoWithRoutes, + ) -> int: """Get the count of healthy routes.""" return sum( 1 diff --git a/src/ai/backend/runner/hash_phrase.py b/src/ai/backend/runner/hash_phrase.py index 896e39d7917..d7d532e6d54 100644 --- a/src/ai/backend/runner/hash_phrase.py +++ b/src/ai/backend/runner/hash_phrase.py @@ -1,21 +1,24 @@ # Based on https://github.com/fpgaminer/hash-phrase/blob/master/hash-phrase.py # but modified to exclude external pbkdf2 implementation +from __future__ import annotations + import binascii import hashlib import json import math import random import sys +from collections.abc import Callable -def pbkdf2_hex(data, salt, iterations, keylen, hashfunc="sha1") -> str: +def pbkdf2_hex(data: str, salt: str, iterations: int, keylen: int, hashfunc: str = "sha1") -> str: dk = hashlib.pbkdf2_hmac( hashfunc, bytes(data, "utf-8"), bytes(salt, "utf-8"), iterations, dklen=keylen ) return binascii.hexlify(dk).decode("utf-8") -def load_dictionary(dictionary_file=None) -> list: +def load_dictionary(dictionary_file: str | None = None) -> list: if dictionary_file is 
None: dictionary_file = "/opt/kernel/words.json" @@ -23,18 +26,18 @@ def load_dictionary(dictionary_file=None) -> list: return json.load(f) -def default_hasher(data) -> str: +def default_hasher(data: str) -> str: return pbkdf2_hex(data, "", iterations=50000, keylen=32, hashfunc="sha256") def hash_phrase( - data, - minimum_entropy=90, - dictionary=None, - hashfunc=default_hasher, - use_numbers=True, - separator="", - capitalize=True, + data: str, + minimum_entropy: int = 90, + dictionary: list | None = None, + hashfunc: Callable[[str], str] = default_hasher, + use_numbers: bool = True, + separator: str = "", + capitalize: bool = True, ) -> str: if dictionary is None: dictionary = load_dictionary() @@ -44,9 +47,9 @@ def hash_phrase( num_words = math.ceil(minimum_entropy / entropy_per_word) # Hash the data and convert to a big integer (converts as Big Endian) - hash = hashfunc(data) - available_entropy = len(hash) * 4 - hash = int(hash, 16) + hash_str = hashfunc(data) + available_entropy = len(hash_str) * 4 + hash_int = int(hash_str, 16) # Check entropy if num_words * entropy_per_word > available_entropy: @@ -62,8 +65,8 @@ def hash_phrase( word_idx_to_replace = -1 for i in range(num_words): - remainder = int(hash % dict_len) - hash = hash / dict_len + remainder = int(hash_int % dict_len) + hash_int = hash_int // dict_len if i == word_idx_to_replace: phrase.append(str(remainder)) else: diff --git a/src/ai/backend/storage/api/client.py b/src/ai/backend/storage/api/client.py index 58464040692..11c2b922db6 100644 --- a/src/ai/backend/storage/api/client.py +++ b/src/ai/backend/storage/api/client.py @@ -8,7 +8,7 @@ import logging import os import urllib.parse -from collections.abc import AsyncGenerator, Mapping, MutableMapping +from collections.abc import AsyncGenerator, Iterator, Mapping, MutableMapping from contextlib import AbstractAsyncContextManager from datetime import UTC, datetime from http import HTTPStatus @@ -224,10 +224,12 @@ async def download_directory_as_archive( Serve a directory as a zip archive on the fly. 
""" - def _iter2aiter(iter) -> AsyncGenerator[Any, None]: + def _iter2aiter(iter: Iterator[Any]) -> AsyncGenerator[Any, None]: """Iterable to async iterable""" - def _consume(loop: asyncio.AbstractEventLoop, iter, q: janus.SyncQueue[Any]) -> None: + def _consume( + loop: asyncio.AbstractEventLoop, iter: Iterator[Any], q: janus.SyncQueue[Any] + ) -> None: for item in iter: q.put(item) q.put(SENTINEL) diff --git a/src/ai/backend/storage/volumes/dellemc/onefs_client.py b/src/ai/backend/storage/volumes/dellemc/onefs_client.py index 783be4476e0..635f58a6642 100644 --- a/src/ai/backend/storage/volumes/dellemc/onefs_client.py +++ b/src/ai/backend/storage/volumes/dellemc/onefs_client.py @@ -87,7 +87,7 @@ async def get_list_lnn(self) -> list[int]: data = await resp.json() return data["storagepools"][0]["lnns"] - async def get_node_hardware_info_by_lnn(self, lnn) -> Mapping[str, Any]: + async def get_node_hardware_info_by_lnn(self, lnn: int) -> Mapping[str, Any]: async with self._request("GET", f"cluster/nodes/{lnn}/hardware") as resp: data = await resp.json() node = data["nodes"][0] @@ -98,7 +98,7 @@ async def get_node_hardware_info_by_lnn(self, lnn) -> Mapping[str, Any]: "serial_number": node["serial_number"], } - async def get_node_status_by_lnn(self, lnn) -> Mapping[str, Any]: + async def get_node_status_by_lnn(self, lnn: int) -> Mapping[str, Any]: async with self._request("GET", f"cluster/nodes/{lnn}/status/nvram") as resp: data = await resp.json() node = data["nodes"][0] @@ -206,7 +206,7 @@ async def create_quota( ) as resp: return await resp.json() - async def delete_quota(self, quota_id) -> None: + async def delete_quota(self, quota_id: str) -> None: async with self._request( "DELETE", f"quota/quotas/{quota_id}", diff --git a/src/ai/backend/storage/volumes/gpfs/gpfs_client.py b/src/ai/backend/storage/volumes/gpfs/gpfs_client.py index 6213cef2546..42a9b128ae6 100644 --- a/src/ai/backend/storage/volumes/gpfs/gpfs_client.py +++ b/src/ai/backend/storage/volumes/gpfs/gpfs_client.py @@ -292,7 +292,7 @@ async def create_fileset( path: Optional[Path] = None, owner: Optional[str] = None, permissions: Optional[int] = None, - create_directory=True, + create_directory: bool = True, ) -> None: body: dict[str, Any] = { "filesetName": fileset_name, diff --git a/src/ai/backend/storage/volumes/netapp/netappclient.py b/src/ai/backend/storage/volumes/netapp/netappclient.py index 90e956219fa..3d235a89dd3 100644 --- a/src/ai/backend/storage/volumes/netapp/netappclient.py +++ b/src/ai/backend/storage/volumes/netapp/netappclient.py @@ -601,7 +601,7 @@ async def get_qos_policies(self) -> list[Mapping[str, Any]]: qos_policies.append(policy) return qos_policies - async def get_qos_by_uuid(self, qos_uuid) -> Mapping[str, Any]: + async def get_qos_by_uuid(self, qos_uuid: str) -> Mapping[str, Any]: async with self.send_request( "get", f"/api/storage/qos/policies/{qos_uuid}", @@ -621,7 +621,7 @@ async def get_qos_by_uuid(self, qos_uuid) -> Mapping[str, Any]: "svm": data["svm"], } - async def get_qos_by_volume_id(self, volume_uuid) -> Mapping[str, Any]: + async def get_qos_by_volume_id(self, volume_uuid: str) -> Mapping[str, Any]: async with self.send_request( "get", f"/api/storage/volumes/{volume_uuid}?fields=qos", diff --git a/src/ai/backend/storage/volumes/weka/weka_client.py b/src/ai/backend/storage/volumes/weka/weka_client.py index b108b2e1f25..1096acf727c 100644 --- a/src/ai/backend/storage/volumes/weka/weka_client.py +++ b/src/ai/backend/storage/volumes/weka/weka_client.py @@ -4,7 +4,7 @@ import ssl import time 
import urllib.parse -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Awaitable, Callable, Iterable, Mapping, MutableMapping from dataclasses import dataclass from datetime import datetime from typing import Any, Optional @@ -91,7 +91,7 @@ def from_json(cls, data: Any) -> WekaFs: ) -def error_handler(inner) -> Any: +def error_handler(inner: Callable[..., Awaitable[Any]]) -> Any: async def outer(*args, **kwargs) -> Any: try: return await inner(*args, **kwargs) diff --git a/src/ai/backend/test/templates/user/user.py b/src/ai/backend/test/templates/user/user.py index 305feb36a19..0909edf9f6d 100644 --- a/src/ai/backend/test/templates/user/user.py +++ b/src/ai/backend/test/templates/user/user.py @@ -81,8 +81,8 @@ async def _create_user_with_keypair( return CreatedUserMeta( email=user_info["email"], password=password, - access_key=keypair_info["access_key"], - secret_key=keypair_info["secret_key"], + access_key=keypair_info["access_key"], # type: ignore[call-overload] + secret_key=keypair_info["secret_key"], # type: ignore[call-overload] ) @override diff --git a/src/ai/backend/test/testcases/session/creation_failure_low_resources.py b/src/ai/backend/test/testcases/session/creation_failure_low_resources.py index b5462953f6f..0af621d89e1 100644 --- a/src/ai/backend/test/testcases/session/creation_failure_low_resources.py +++ b/src/ai/backend/test/testcases/session/creation_failure_low_resources.py @@ -19,7 +19,7 @@ async def test(self) -> None: result = await client_session.Image.get( image_dep.name, image_dep.architecture, fields=[image_fields["labels"]] ) - labels = result["labels"] + labels = result["labels"] # type: ignore[call-overload] min_mem_label = next( filter(lambda label: label["key"] == "ai.backend.resource.min.mem", labels), None ) diff --git a/src/ai/backend/testutils/db.py b/src/ai/backend/testutils/db.py index 555da94b17b..87b97564e20 100644 --- a/src/ai/backend/testutils/db.py +++ b/src/ai/backend/testutils/db.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator, Sequence from contextlib import asynccontextmanager -from typing import Protocol +from typing import Any, Protocol from sqlalchemy import Table, text from sqlalchemy.ext.asyncio import AsyncEngine @@ -18,7 +18,7 @@ class HasTable(Protocol): type TableOrORM = Table | type[HasTable] -def _create_tables_sync(conn, tables: list[Table]) -> None: +def _create_tables_sync(conn: Any, tables: list[Table]) -> None: """ Sync function to create tables using MetaData.create_all(). diff --git a/src/ai/backend/testutils/mock.py b/src/ai/backend/testutils/mock.py index 06ecc9dcb24..b9eb0118076 100644 --- a/src/ai/backend/testutils/mock.py +++ b/src/ai/backend/testutils/mock.py @@ -8,14 +8,14 @@ from aioresponses import CallbackResult -def mock_corofunc(return_value) -> Mock: +def mock_corofunc(return_value: Any) -> Mock: """ Return mock coroutine function. Python's default mock module does not support coroutines. """ - async def _mock_corofunc(*args, **kargs) -> Any: + async def _mock_corofunc(*args: Any, **kwargs: Any) -> Any: return return_value return mock.Mock(wraps=_mock_corofunc) @@ -41,7 +41,7 @@ class AsyncContextManagerMock: passing `kwargs`. 
""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.context = kwargs for k, v in kwargs.items(): setattr(self, k, v) @@ -49,7 +49,9 @@ def __init__(self, *args, **kwargs) -> None: async def __aenter__(self) -> "AsyncMock": return AsyncMock(**self.context) - async def __aexit__(self, exc_type, exc_value, exc_tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: Any + ) -> None: pass @@ -62,19 +64,19 @@ class MockableZMQAsyncSock: def create_mock(cls) -> Mock: return mock.Mock(cls()) - def bind(self, addr) -> None: + def bind(self, addr: str) -> None: pass - def connect(self, addr) -> None: + def connect(self, addr: str) -> None: pass def close(self) -> None: pass - async def send(self, frame) -> None: + async def send(self, frame: bytes) -> None: pass - async def send_multipart(self, msg) -> None: + async def send_multipart(self, msg: list[bytes]) -> None: pass async def recv(self) -> None: @@ -112,7 +114,9 @@ class AsyncContextMock(mock.Mock): async def __aenter__(self) -> "AsyncContextMock": return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> None: pass @@ -127,7 +131,9 @@ class AsyncContextMagicMock(mock.MagicMock): async def __aenter__(self) -> "AsyncContextMagicMock": return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> None: pass @@ -159,7 +165,9 @@ class AsyncContextCoroutineMock(AsyncMock): async def __aenter__(self) -> "AsyncContextCoroutineMock": return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + ) -> None: pass @@ -173,7 +181,7 @@ def mock_aioresponses_sequential_payloads( """ cb_call_counter = 0 - def _callback(*args, **kwargs) -> CallbackResult: + def _callback(*args: Any, **kwargs: Any) -> CallbackResult: nonlocal cb_call_counter if cb_call_counter >= len(mock_responses): @@ -187,7 +195,7 @@ def _callback(*args, **kwargs) -> CallbackResult: def setup_dockerhub_mocking( - aiohttp_request_mock, registry_url: str, dockerhub_responses_mock: dict[str, Any] + aiohttp_request_mock: Any, registry_url: str, dockerhub_responses_mock: dict[str, Any] ) -> None: # /v2/ endpoint aiohttp_request_mock.get( diff --git a/src/ai/backend/web/logging.py b/src/ai/backend/web/logging.py index 445301b87df..9cbf05f013a 100644 --- a/src/ai/backend/web/logging.py +++ b/src/ai/backend/web/logging.py @@ -1,12 +1,15 @@ from __future__ import annotations import logging +from collections.abc import Mapping +from types import TracebackType +from typing import Any class BraceMessage: __slots__ = ("args", "fmt") - def __init__(self, fmt, args) -> None: + def __init__(self, fmt: str, args: tuple[Any, ...]) -> None: self.fmt = fmt self.args = args @@ -15,10 +18,26 @@ def __str__(self) -> str: class BraceStyleAdapter(logging.LoggerAdapter): - def __init__(self, logger, extra=None) -> None: + def __init__(self, logger: logging.Logger, extra: Mapping[str, Any] | None = None) -> None: super().__init__(logger, extra) - def log(self, level, msg, *args, **kwargs) -> None: + def log( + self, + level: int, + msg: object, + *args: object, + exc_info: ( + bool + | 
tuple[type[BaseException], BaseException, TracebackType | None] | tuple[None, None, None] | BaseException | None ) = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + **kwargs: object, + ) -> None: if self.isEnabledFor(level): msg, processed_kwargs = self.process(msg, kwargs) - self.logger._log(level, BraceMessage(msg, args), (), **processed_kwargs) + self.logger._log(level, BraceMessage(str(msg), args), (), exc_info=exc_info, stack_info=stack_info, stacklevel=stacklevel, **processed_kwargs) diff --git a/src/ai/backend/web/proxy.py b/src/ai/backend/web/proxy.py index 0c2e6ef6c91..05f7aff3677 100644 --- a/src/ai/backend/web/proxy.py +++ b/src/ai/backend/web/proxy.py @@ -5,7 +5,7 @@ import json import logging import random -from collections.abc import Iterable +from collections.abc import Awaitable, Callable, Iterable from typing import Final, Optional, cast import aiohttp @@ -150,7 +150,10 @@ def _decrypt_payload(endpoint: str, payload: bytes) -> bytes: @web.middleware -async def decrypt_payload(request: web.Request, handler) -> web.StreamResponse: +async def decrypt_payload( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> web.StreamResponse: config: WebServerUnifiedConfig = request.app["config"] try: request_headers = extra_config_headers.check(request.headers) @@ -566,7 +569,7 @@ async def web_plugin_handler( async def websocket_handler( request: web.Request, *, - is_anonymous=False, + is_anonymous: bool = False, api_endpoint: Optional[str] = None, jwt_token: Optional[str] = None, ) -> web.StreamResponse: diff --git a/src/ai/backend/web/stats.py b/src/ai/backend/web/stats.py index dbdaab408b6..4424644d977 100644 --- a/src/ai/backend/web/stats.py +++ b/src/ai/backend/web/stats.py @@ -2,6 +2,7 @@ import asyncio import weakref +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from aiohttp import web @@ -26,7 +27,10 @@ class WebStats: @web.middleware -async def track_active_handlers(request: web.Request, handler) -> web.StreamResponse: +async def track_active_handlers( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> web.StreamResponse: stats: WebStats = request.app["stats"] stats.active_handlers.add(asyncio.current_task()) # type: ignore return await handler(request) diff --git a/tests/unit/manager/models/gql_models/BUILD b/tests/unit/manager/models/gql_models/BUILD deleted file mode 100644 index 45a9bd6b95f..00000000000 --- a/tests/unit/manager/models/gql_models/BUILD +++ /dev/null @@ -1,5 +0,0 @@ -python_test_utils() - -python_tests( - name="tests", -) diff --git a/tests/unit/manager/models/gql_models/test_service_config.py b/tests/unit/manager/models/gql_models/test_service_config.py deleted file mode 100644 index e637b40a80f..00000000000 --- a/tests/unit/manager/models/gql_models/test_service_config.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from unittest.mock import MagicMock - -import pytest - -from ai.backend.manager.api.gql_legacy.service_config import ServiceConfigNode - - -@pytest.mark.asyncio -async def test_service_config_node_load_returns_dict_not_tuple() -> None: - """Regression test for BA-3595: unified_config should be dict, not tuple.""" - # Mock ResolveInfo and context - mock_info = MagicMock() - mock_config = MagicMock() - mock_config.model_dump.return_value = {"key": "value"} - mock_config.model_json_schema.return_value = {"type": "object"} - mock_info.context.config_provider.config = mock_config - - # Call the 
method - result = await ServiceConfigNode.load(mock_info, "manager") - - # Verify configuration is dict, not tuple (BA-3595 regression) - assert isinstance(result.configuration, dict), ( - f"configuration should be dict, got {type(result.configuration)}" - ) - assert result.configuration == {"key": "value"} - assert result.schema == {"type": "object"} - assert result.service == "manager" diff --git a/tests/unit/manager/notification/BUILD b/tests/unit/manager/notification/BUILD deleted file mode 100644 index 57341b1358b..00000000000 --- a/tests/unit/manager/notification/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -python_tests( - name="tests", -) diff --git a/tests/unit/manager/notification/__init__.py b/tests/unit/manager/notification/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/unit/manager/notification/channels/BUILD b/tests/unit/manager/notification/channels/BUILD deleted file mode 100644 index 57341b1358b..00000000000 --- a/tests/unit/manager/notification/channels/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -python_tests( - name="tests", -) diff --git a/tests/unit/manager/notification/channels/__init__.py b/tests/unit/manager/notification/channels/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/unit/manager/notification/channels/email/BUILD b/tests/unit/manager/notification/channels/email/BUILD deleted file mode 100644 index 57341b1358b..00000000000 --- a/tests/unit/manager/notification/channels/email/BUILD +++ /dev/null @@ -1,3 +0,0 @@ -python_tests( - name="tests", -) diff --git a/tests/unit/manager/notification/channels/email/__init__.py b/tests/unit/manager/notification/channels/email/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/unit/manager/notification/channels/email/test_channel.py b/tests/unit/manager/notification/channels/email/test_channel.py deleted file mode 100644 index 50027bad078..00000000000 --- a/tests/unit/manager/notification/channels/email/test_channel.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Unit tests for EmailChannel.""" - -from __future__ import annotations - -import smtplib -from collections.abc import Generator -from unittest.mock import MagicMock, patch - -import pytest - -from ai.backend.common.data.notification.types import ( - EmailMessage, - EmailSpec, - SMTPAuth, - SMTPConnection, -) -from ai.backend.manager.errors.notification import NotificationProcessingFailure -from ai.backend.manager.notification.channels.email.channel import EmailChannel -from ai.backend.manager.notification.types import NotificationMessage, SendResult - - -class TestEmailChannel: - """Test cases for EmailChannel.""" - - @pytest.fixture - def mock_smtp(self) -> Generator[MagicMock, None, None]: - """ - Provide a mocked smtplib.SMTP context manager. - - The mock SMTP server has all methods (starttls, login, send_message) - set up as successful MagicMocks by default. 
- """ - with patch("ai.backend.manager.notification.channels.email.channel.smtplib") as mock_module: - mock_server = MagicMock() - mock_module.SMTP.return_value.__enter__ = MagicMock(return_value=mock_server) - mock_module.SMTP.return_value.__exit__ = MagicMock(return_value=False) - - # Default success behavior - mock_server.starttls = MagicMock() - mock_server.login = MagicMock() - mock_server.send_message = MagicMock() - - # Preserve real exception classes for error handling tests - mock_module.SMTPConnectError = smtplib.SMTPConnectError - mock_module.SMTPAuthenticationError = smtplib.SMTPAuthenticationError - mock_module.SMTPException = smtplib.SMTPException - - yield mock_module - - @pytest.fixture - def basic_spec(self) -> EmailSpec: - """Basic email specification with authentication.""" - return EmailSpec( - smtp=SMTPConnection( - host="smtp.example.com", - port=587, - ), - message=EmailMessage( - from_email="noreply@example.com", - to_emails=["admin@example.com"], - subject_template="Test Notification", - ), - auth=SMTPAuth( - username="user@example.com", - password="password123", - ), - ) - - @pytest.fixture - def no_auth_spec(self) -> EmailSpec: - """Specification without authentication (relay server).""" - return EmailSpec( - smtp=SMTPConnection( - host="smtp.example.com", - port=25, - use_tls=False, - ), - message=EmailMessage( - from_email="noreply@example.com", - to_emails=["admin@example.com"], - ), - ) - - @pytest.fixture - def multi_recipient_spec(self) -> EmailSpec: - """Specification with multiple recipients.""" - return EmailSpec( - smtp=SMTPConnection( - host="smtp.example.com", - port=587, - ), - message=EmailMessage( - from_email="noreply@example.com", - to_emails=["admin1@example.com", "admin2@example.com", "admin3@example.com"], - ), - auth=SMTPAuth( - username="user@example.com", - password="password123", - ), - ) - - @pytest.fixture - def no_subject_spec(self) -> EmailSpec: - """Specification without subject template (uses message first line).""" - return EmailSpec( - smtp=SMTPConnection( - host="smtp.example.com", - port=587, - ), - message=EmailMessage( - from_email="noreply@example.com", - to_emails=["admin@example.com"], - subject_template=None, - ), - auth=SMTPAuth( - username="user@example.com", - password="password123", - ), - ) - - # ============================================================ - # Tests: Success Cases - # ============================================================ - - @pytest.mark.asyncio - async def test_send_success_with_auth( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test successful email sending with authentication.""" - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test notification message") - - result = await channel.send(message) - - assert isinstance(result, SendResult) - # Verify SMTP was initialized with correct parameters - mock_smtp.SMTP.assert_called_once_with( - "smtp.example.com", 587, timeout=basic_spec.smtp.timeout - ) - # Get the server instance from context manager - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - mock_server.starttls.assert_called_once() - mock_server.login.assert_called_once_with("user@example.com", "password123") - mock_server.send_message.assert_called_once() - - @pytest.mark.asyncio - async def test_send_success_without_auth( - self, - mock_smtp: MagicMock, - no_auth_spec: EmailSpec, - ) -> None: - """Test successful email sending without authentication (relay server).""" - channel = 
EmailChannel(email_spec=no_auth_spec) - message = NotificationMessage(message="Test relay message") - - result = await channel.send(message) - - assert isinstance(result, SendResult) - # Get the server instance from context manager - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - # No TLS and no login for relay server - mock_server.starttls.assert_not_called() - mock_server.login.assert_not_called() - mock_server.send_message.assert_called_once() - - @pytest.mark.asyncio - async def test_send_to_multiple_recipients( - self, - mock_smtp: MagicMock, - multi_recipient_spec: EmailSpec, - ) -> None: - """Test email is sent to all recipients.""" - channel = EmailChannel(email_spec=multi_recipient_spec) - message = NotificationMessage(message="Test multi-recipient message") - - await channel.send(message) - - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - sent_msg = mock_server.send_message.call_args[0][0] - assert "admin1@example.com" in sent_msg["To"] - assert "admin2@example.com" in sent_msg["To"] - assert "admin3@example.com" in sent_msg["To"] - - @pytest.mark.asyncio - async def test_subject_from_config_template( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test email uses subject_template from spec.""" - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test message body") - - await channel.send(message) - - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - sent_msg = mock_server.send_message.call_args[0][0] - # basic_config has subject_template="Test Notification" - assert sent_msg["Subject"] == "Test Notification" - - @pytest.mark.asyncio - async def test_subject_defaults_to_message_first_line( - self, - mock_smtp: MagicMock, - no_subject_spec: EmailSpec, - ) -> None: - """Test subject defaults to first line of message when not provided.""" - channel = EmailChannel(email_spec=no_subject_spec) - # No subject provided, should use first line of message - message = NotificationMessage(message="First Line Subject\nRest of the message body") - - await channel.send(message) - - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - sent_msg = mock_server.send_message.call_args[0][0] - assert sent_msg["Subject"] == "First Line Subject" - - # ============================================================ - # Tests: Error Cases - # ============================================================ - - @pytest.mark.asyncio - async def test_connection_error_raises_failure( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test connection failure raises NotificationProcessingFailure.""" - mock_smtp.SMTP.return_value.__enter__.side_effect = smtplib.SMTPConnectError( - 421, "Connection refused" - ) - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test message") - - with pytest.raises(NotificationProcessingFailure): - await channel.send(message) - - @pytest.mark.asyncio - async def test_auth_error_raises_failure( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test authentication failure raises NotificationProcessingFailure.""" - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - mock_server.login.side_effect = smtplib.SMTPAuthenticationError( - 535, "Authentication failed" - ) - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test message") - - with pytest.raises(NotificationProcessingFailure): - await channel.send(message) - - 
@pytest.mark.asyncio - async def test_smtp_error_raises_failure( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test SMTP error raises NotificationProcessingFailure.""" - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - mock_server.send_message.side_effect = smtplib.SMTPException("SMTP error") - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test message") - - with pytest.raises(NotificationProcessingFailure): - await channel.send(message) - - @pytest.mark.asyncio - async def test_use_tls_option( - self, - mock_smtp: MagicMock, - basic_spec: EmailSpec, - ) -> None: - """Test use_tls option controls STARTTLS.""" - # Test with TLS enabled (default) - channel = EmailChannel(email_spec=basic_spec) - message = NotificationMessage(message="Test message") - - await channel.send(message) - - mock_server = mock_smtp.SMTP.return_value.__enter__.return_value - mock_server.starttls.assert_called_once() - - @pytest.mark.asyncio - async def test_timeout_passed_to_smtp( - self, - mock_smtp: MagicMock, - ) -> None: - """Test timeout is passed to SMTP constructor.""" - spec = EmailSpec( - smtp=SMTPConnection( - host="smtp.example.com", - port=587, - timeout=120, - ), - message=EmailMessage( - from_email="noreply@example.com", - to_emails=["admin@example.com"], - ), - ) - channel = EmailChannel(email_spec=spec) - message = NotificationMessage(message="Test message") - - await channel.send(message) - - mock_smtp.SMTP.assert_called_once_with("smtp.example.com", 587, timeout=120) diff --git a/tools/build-macros.py b/tools/build-macros.py index 537b59318a1..5a1afab4a05 100644 --- a/tools/build-macros.py +++ b/tools/build-macros.py @@ -1,4 +1,4 @@ -def visibility_private_component(**kwargs) -> None: +def visibility_private_component(**kwargs: object) -> None: """Private package not expected to be imported by anything else than itself.""" allowed_dependencies = kwargs.get("allowed_dependencies", []) allowed_dependents = kwargs.get("allowed_dependents", []) @@ -25,7 +25,9 @@ def visibility_private_component(**kwargs) -> None: ) -def common_scie_config(build_style, *, entry_point="ai.backend.cli.__main__") -> dict[str, object]: +def common_scie_config( + build_style: str, *, entry_point: str = "ai.backend.cli.__main__" +) -> dict[str, object]: build_style_to_tag = { "lazy": "lazy", "eager": "fat",