Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8388.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable ANN001 linter rule - Add type annotations to function arguments
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,8 @@ ignore = [
"A001", # builtin-variable-shadowing - mostly id/filter/input in GraphQL (75 violations)
"A002", # builtin-argument-shadowing - mostly id/filter/input in GraphQL (297 violations)
# ANN (annotations) - defer for gradual migration, mypy handles type checking
"ANN001", # missing-type-function-argument - 1336 violations, defer for gradual migration
"ANN002", # missing-type-args - *args type annotation, low practical value
"ANN003", # missing-type-kwargs - **kwargs type annotation, low practical value
# "ANN201", # missing-return-type-undocumented-public-function - ENABLING: analyzing patterns
# "ANN202", # missing-return-type-private-function - ENABLED: fixing violations incrementally
# "ANN205", # missing-return-type-static-method - ENABLED: fixing violations incrementally
# "ANN206", # missing-return-type-class-method - ENABLED: fixing violations incrementally
"ANN401", # dynamically-typed-expression - Any needed for GraphQL resolvers, dynamic APIs
]

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/cuda_open/nvidia.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _ensure_lib(cls) -> None:
raise ImportError(f"Could not load the {cls.name} library!")

@classmethod
def invoke(cls, func_name, *args, check_rc=True) -> int:
def invoke(cls, func_name: str, *args: Any, check_rc: bool = True) -> int:
try:
cls._ensure_lib()
except ImportError:
Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/accelerator/ipu/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from pathlib import Path
from zipfile import ZipInfo

from Cython.Build import cythonize
from setuptools import setup
Expand All @@ -15,7 +16,9 @@
)


def _filtered_writestr(self, zinfo_or_arcname, bytes, compress_type=None) -> None:
def _filtered_writestr(
self: WheelFile, zinfo_or_arcname: ZipInfo | str, bytes: bytes, compress_type: int | None = None
) -> None:
global exclude_source_files
if exclude_source_files:
if isinstance(zinfo_or_arcname, str):
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/rebellions/atom/pci.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path


async def read_sysfs(path, attr) -> str:
async def read_sysfs(path: Path, attr: str) -> str:
def _blocking() -> str:
return (path / attr).read_text().strip()

Expand Down
8 changes: 4 additions & 4 deletions src/ai/backend/accelerator/rocm/exception.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
class NoRocmDeviceError(Exception):
def __init__(self, message) -> None:
def __init__(self, message: str) -> None:
self.message = message


class GenericRocmError(Exception):
def __init__(self, message) -> None:
def __init__(self, message: str) -> None:
self.message = message


class RocmUtilFetchError(Exception):
def __init__(self, message) -> None:
def __init__(self, message: str) -> None:
self.message = message


class RocmMemFetchError(Exception):
def __init__(self, message) -> None:
def __init__(self, message: str) -> None:
self.message = message


Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/tenstorrent/n300/pci.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path


async def read_sysfs(path, attr) -> str:
async def read_sysfs(path: Path, attr: str) -> str:
def _blocking() -> str:
return (path / attr).read_text().strip()

Expand Down
5 changes: 4 additions & 1 deletion src/ai/backend/accelerator/tpu/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DiscretePropertyAllocMap,
)
from ai.backend.agent.stats import ContainerMeasurement, NodeMeasurement, StatContext
from ai.backend.agent.types import Container
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import (
BinarySize,
Expand Down Expand Up @@ -194,7 +195,9 @@ async def get_attached_devices(
})
return attached_devices

async def restore_from_container(self, container, alloc_map) -> None:
async def restore_from_container(
self, container: Container, alloc_map: AbstractAllocMap
) -> None:
# TODO: implement!
pass

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/tpu/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class libtpu:
zone: ClassVar[Optional[str]] = None

@classmethod
async def _run_ctpu(cls, cmd) -> str:
async def _run_ctpu(cls, cmd: list[str]) -> str:
if not cls.zone:
try:
proc = await subprocess.create_subprocess_exec(
Expand Down
12 changes: 6 additions & 6 deletions src/ai/backend/account_manager/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

def auth_required(handler: Handler) -> Handler:
@functools.wraps(handler)
async def wrapped(request, *args, **kwargs) -> web.StreamResponse:
async def wrapped(request: web.Request, *args, **kwargs) -> web.StreamResponse:
if request.get("is_authorized", False):
return await handler(request, *args, **kwargs)
raise AuthorizationFailed("Unauthorized access")
Expand All @@ -28,15 +28,15 @@ async def wrapped(request, *args, **kwargs) -> web.StreamResponse:
return wrapped


def set_handler_attr(func, key, value) -> None:
def set_handler_attr(func: Callable, key: str, value: Any) -> None:
attrs = getattr(func, "_backend_attrs", None)
if attrs is None:
attrs = {}
attrs[key] = value
func._backend_attrs = attrs
func._backend_attrs = attrs # type: ignore[attr-defined]


def get_handler_attr(request, key, default=None) -> Any:
def get_handler_attr(request: web.Request, key: str, default: Any = None) -> Any:
# When used in the aiohttp server-side codes, we should use
# request.match_info.hanlder instead of handler passed to the middleware
# functions because aiohttp wraps this original handler with functools.partial
Expand Down Expand Up @@ -116,7 +116,7 @@ def ensure_stream_response_type[TAnyResponse: web.StreamResponse](

def pydantic_api_response_handler(
handler: THandlerFuncWithoutParam,
is_deprecated=False,
is_deprecated: bool = False,
) -> Handler:
"""
Only for API handlers which does not require request body.
Expand All @@ -141,7 +141,7 @@ def pydantic_api_handler[TParamModel: BaseModel, TQueryModel: BaseModel](
checker: type[TParamModel],
loads: Callable[[str], Any] | None = None,
query_param_checker: type[TQueryModel] | None = None,
is_deprecated=False,
is_deprecated: bool = False,
) -> Callable[[THandlerFuncWithParam], Handler]:
def wrap(
handler: THandlerFuncWithParam,
Expand Down
14 changes: 7 additions & 7 deletions src/ai/backend/account_manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ class GUID[UUID_SubType: uuid.UUID](TypeDecorator):
uuid_subtype_func: ClassVar[Callable[[Any], Any]] = lambda v: v
cache_ok = True

def load_dialect_impl(self, dialect) -> TypeDecorator:
def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
return dialect.type_descriptor(CHAR(16))
return cast(TypeDecorator, dialect.type_descriptor(UUID()))
return cast(TypeDecorator, dialect.type_descriptor(CHAR(16)))

def process_bind_param(
self, value: UUID_SubType | uuid.UUID | None, dialect
self, value: UUID_SubType | uuid.UUID | None, dialect: Dialect
) -> str | bytes | None:
# NOTE: EndpointId, SessionId, KernelId are *not* actual types defined as classes,
# but a "virtual" type that is an identity function at runtime.
Expand All @@ -82,7 +82,7 @@ def process_bind_param(
case _:
return uuid.UUID(value).bytes

def process_result_value(self, value: Any, dialect) -> UUID_SubType | None:
def process_result_value(self, value: Any, dialect: Dialect) -> UUID_SubType | None:
if value is None:
return value
cls = type(self)
Expand Down Expand Up @@ -134,9 +134,9 @@ def python_type(self) -> type[T_StrEnum]:
class PasswordColumn(TypeDecorator):
impl = VARCHAR

def process_bind_param(self, value, dialect) -> str:
def process_bind_param(self, value: Any, dialect: Dialect) -> str:
return hash_password(value)


def IDColumn(name="id") -> sa.Column:
def IDColumn(name: str = "id") -> sa.Column:
return sa.Column(name, GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
18 changes: 9 additions & 9 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ async def generate_accelerator_mounts(
raise NotImplementedError

@abstractmethod
def resolve_krunner_filepath(self, filename) -> Path:
def resolve_krunner_filepath(self, filename: str) -> Path:
"""
Return matching krunner path object for given filename.
"""
Expand All @@ -502,7 +502,7 @@ async def prepare_container(
self,
resource_spec: KernelResourceSpec,
environ: Mapping[str, str],
service_ports,
service_ports: list[ServicePort],
cluster_info: ClusterInfo,
) -> KernelObjectType:
raise NotImplementedError
Expand All @@ -512,8 +512,8 @@ async def start_container(
self,
kernel_obj: AbstractKernel,
cmdargs: list[str],
resource_opts,
preopen_ports,
resource_opts: Optional[Mapping[str, Any]],
preopen_ports: list[int],
cluster_info: ClusterInfo,
) -> Mapping[str, Any]:
raise NotImplementedError
Expand Down Expand Up @@ -582,9 +582,9 @@ async def mount_krunner(
environ: MutableMapping[str, str],
) -> None:
def _mount(
type,
src,
dst,
type: MountTypes,
src: str | Path,
dst: str | Path,
) -> None:
resource_spec.mounts.append(
self.get_runner_mount(
Expand Down Expand Up @@ -3687,7 +3687,7 @@ async def shutdown_service(self, kernel_id: KernelId, service: str) -> None:

async def commit(
self,
reporter,
reporter: Any,
kernel_id: KernelId,
subdir: str,
*,
Expand Down Expand Up @@ -3719,7 +3719,7 @@ async def list_files(self, kernel_id: KernelId, path: str) -> dict[str, Any]:
async def ping_kernel(self, kernel_id: KernelId) -> dict[str, float] | None:
return await self.kernel_registry[kernel_id].ping()

async def save_last_registry(self, force=False) -> None:
async def save_last_registry(self, force: bool = False) -> None:
await self._write_kernel_registry_to_recovery(
self.kernel_registry,
KernelRegistrySaveMetadata(force),
Expand Down
14 changes: 7 additions & 7 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import zmq
import zmq.asyncio
from aiodocker.docker import Docker, DockerContainer
from aiodocker.exceptions import DockerError
from aiodocker.exceptions import DockerContainerError, DockerError
from aiodocker.types import PortInfo
from aiomonitor.task import preserve_termination_log
from aiotools import TaskGroup
Expand Down Expand Up @@ -192,7 +192,7 @@
)


async def get_extra_volumes(docker, lang) -> list[VolumeInfo]:
async def get_extra_volumes(docker: Docker, lang: str) -> list[VolumeInfo]:
avail_volumes = (await docker.volumes.list())["Volumes"]
if not avail_volumes:
return []
Expand Down Expand Up @@ -265,14 +265,14 @@ async def _clean_scratch(
pass


def _DockerError_reduce(self) -> tuple[type, tuple[Any, ...]]:
def _DockerError_reduce(self: DockerError) -> tuple[type, tuple[Any, ...]]:
return (
type(self),
(self.status, {"message": self.message}, *self.args),
)


def _DockerContainerError_reduce(self) -> tuple[type, tuple[Any, ...]]:
def _DockerContainerError_reduce(self: DockerContainerError) -> tuple[type, tuple[Any, ...]]:
return (
type(self),
(self.status, {"message": self.message}, self.container_id, *self.args),
Expand Down Expand Up @@ -644,7 +644,7 @@ async def get_intrinsic_mounts(self) -> Sequence[Mount]:
return mounts

@override
def resolve_krunner_filepath(self, filename) -> Path:
def resolve_krunner_filepath(self, filename: str) -> Path:
return Path(
pkg_resources.resource_filename(
"ai.backend.runner",
Expand Down Expand Up @@ -1014,8 +1014,8 @@ async def start_container(
self,
kernel_obj: AbstractKernel,
cmdargs: list[str],
resource_opts,
preopen_ports,
resource_opts: Optional[Mapping[str, Any]],
preopen_ports: list[int],
cluster_info: ClusterInfo,
) -> Mapping[str, Any]:
loop = current_loop()
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/docker/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
log.info("Automatic ~/.output file S3 uploads is disabled.")


def relpath(path, base) -> Path:
def relpath(path: Path | str, base: Path | str) -> Path:
return Path(path).resolve().relative_to(Path(base).resolve())


Expand Down Expand Up @@ -52,7 +52,7 @@ def scandir(root: Path, allowed_max_size: int) -> dict[Path, float]:
return file_stats


def diff_file_stats(fs1, fs2) -> set[Path]:
def diff_file_stats(fs1: dict[Path, float], fs2: dict[Path, float]) -> set[Path]:
k2 = set(fs2.keys())
k1 = set(fs1.keys())
new_files = k2 - k1
Expand Down
14 changes: 7 additions & 7 deletions src/ai/backend/agent/docker/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
async def generate_docker_args(
self,
docker: Docker,
device_alloc,
device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
) -> Mapping[str, Any]:
cores = [*map(int, device_alloc["cpu"].keys())]
cores = [*map(int, device_alloc[SlotName("cpu")].keys())]
sorted_core_ids = [*map(str, sorted(cores))]
return {
"HostConfig": {
Expand Down Expand Up @@ -933,9 +933,9 @@ async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]:
async def generate_docker_args(
self,
docker: Docker,
device_alloc,
device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]],
) -> Mapping[str, Any]:
memory = sum(device_alloc["mem"].values())
memory = sum(device_alloc[SlotName("mem")].values())
return {
"HostConfig": {
"MemorySwap": int(memory), # prevent using swap!
Expand Down Expand Up @@ -1046,7 +1046,7 @@ async def get_capabilities(self) -> set[ContainerNetworkCapability]:
return {ContainerNetworkCapability.GLOBAL}

async def join_network(
self, kernel_config: KernelCreationConfig, cluster_info: ClusterInfo, **kwargs
self, kernel_config: KernelCreationConfig, cluster_info: ClusterInfo, **kwargs: Any
) -> dict[str, Any]:
if _cluster_ssh_port_mapping := cluster_info.get("cluster_ssh_port_mapping"):
return {
Expand All @@ -1068,7 +1068,7 @@ async def leave_network(self, kernel: DockerKernel) -> None:
pass

async def prepare_port_forward(
self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs
self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs: Any
) -> None:
host_ports = [p[0] for p in ports]
scratch_dir = (
Expand All @@ -1091,7 +1091,7 @@ async def prepare_port_forward(
)

async def expose_ports(
self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs
self, kernel: DockerKernel, bind_host: str, ports: Iterable[tuple[int, int]], **kwargs: Any
) -> ContainerNetworkInfo:
host_ports = [p[0] for p in ports]

Expand Down
Loading
Loading