diff --git a/setup.cfg b/setup.cfg
index 9f25c5082..60d243438 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -40,6 +40,7 @@ install_requires =
     aiohttp~=3.8.0
     aiohttp_cors~=0.7
     aiohttp_sse~=2.0
+    aiolimiter~=1.0
     aiomonitor~=0.4.5
     aioredis[hiredis]~=2.0
     aiotools~=1.4.0
diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py
index 02dd1f3e7..665a56482 100644
--- a/src/ai/backend/manager/container_registry/base.py
+++ b/src/ai/backend/manager/container_registry/base.py
@@ -4,9 +4,22 @@
 from contextvars import ContextVar
 import logging
 import json
-from typing import Any, AsyncIterator, Dict, Mapping, Optional, TYPE_CHECKING, cast
+from pathlib import Path
+import pickle
+import tempfile
+from typing import (
+    Any,
+    AsyncIterator,
+    ClassVar,
+    Dict,
+    Mapping,
+    Optional,
+    TYPE_CHECKING,
+    cast,
+)
 
 import aiohttp
+from aiolimiter import AsyncLimiter
 import aiotools
 import yarl
@@ -27,9 +40,42 @@
 log = BraceStyleAdapter(logging.getLogger(__name__))
 
+_limiters: dict[str, AsyncLimiter] = {}
+
+
+def _load_limiter(name: str, default_config: tuple[float, float]) -> AsyncLimiter:
+    if (o := _limiters.get(name)) is not None:
+        return o
+    path = Path(tempfile.gettempdir(), f"bai.container_registry.asynclimiter.{name}")
+    try:
+        with open(path, "rb") as f:
+            o = cast(AsyncLimiter, pickle.load(f))
+        if o.max_rate != default_config[0] or o.time_period != default_config[1]:
+            o = AsyncLimiter(max_rate=default_config[0], time_period=default_config[1])
+            print(f"Recreated async limiter from pickled file: {name}", flush=True)
+        else:
+            print(f"Loaded async limiter from pickled file: {name}", flush=True)
+    except (OSError, pickle.PickleError):
+        o = AsyncLimiter(max_rate=default_config[0], time_period=default_config[1])
+    _limiters[name] = o
+    return o
+
+
+def _save_limiter(name: str, config: AsyncLimiter) -> None:
+    path = Path(tempfile.gettempdir(), f"bai.container_registry.asynclimiter.{name}")
+    try:
+        with open(path, "wb") as f:
+            pickle.dump(config, f)
+        print(f"Saved async limiter to pickled file: {name}", flush=True)
+    except OSError:
+        log.debug("Failed to store async limiter ({}) status in {}", name, path)
+
+
 class BaseContainerRegistry(metaclass=ABCMeta):
 
+    default_rate_limit: ClassVar = (200, 30)
+    manifest_rate_limit: ClassVar = (200, 30)
+
     etcd: AsyncEtcd
     registry_name: str
     registry_info: Mapping[str, Any]
@@ -65,6 +111,20 @@ def __init__(
         self.sema = ContextVar('sema')
         self.reporter = ContextVar('reporter', default=None)
         self.all_updates = ContextVar('all_updates')
+        self._limiter_prefix = f"{hash(self.registry_url):x}"
+        # TODO: Use a per-registry global lock to store/load rate limiter states.
+        self.default_rate_limiter = AsyncLimiter(*type(self).default_rate_limit)
+        self.manifest_rate_limiter = AsyncLimiter(*type(self).manifest_rate_limit)
+        # self.default_rate_limiter = _load_limiter(
+        #     f"{self._limiter_prefix}.default",
+        #     type(self).default_rate_limit,
+        # )
+        # self.manifest_rate_limiter = _load_limiter(
+        #     f"{self._limiter_prefix}.manifest",
+        #     type(self).manifest_rate_limit,
+        # )
+        # atexit.register(_save_limiter, f"{self._limiter_prefix}.default", self.default_rate_limiter)
+        # atexit.register(_save_limiter, f"{self._limiter_prefix}.manifest", self.manifest_rate_limiter)
 
     async def rescan_single_registry(
         self,
@@ -118,7 +178,10 @@ async def _scan_image(
             {'n': '10'},
         )
         while tag_list_url is not None:
-            async with sess.get(tag_list_url, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(tag_list_url, **rqst_args) as resp,
+            ):
                 data = json.loads(await resp.read())
                 if 'tags' in data:
                     # sometimes there are dangling image names in the hub.
@@ -150,8 +213,13 @@ async def _scan_tag(
         skip_reason = None
         try:
             async with self.sema.get():
-                async with sess.get(self.registry_url / f'v2/{image}/manifests/{tag}',
-                                    **rqst_args) as resp:
+                async with (
+                    self.manifest_rate_limiter,
+                    sess.get(
+                        self.registry_url / f'v2/{image}/manifests/{tag}',
+                        **rqst_args,
+                    ) as resp,
+                ):
                     if resp.status == 404:
                         # ignore missing tags
                         # (may occur after deleting an image from the docker hub)
@@ -163,8 +231,13 @@ async def _scan_tag(
                 size_bytes = (sum(layer['size'] for layer in data['layers']) +
                               data['config']['size'])
 
-            async with sess.get(self.registry_url / f'v2/{image}/blobs/{config_digest}',
-                                **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(
+                    self.registry_url / f'v2/{image}/blobs/{config_digest}',
+                    **rqst_args,
+                ) as resp,
+            ):
                 # content-type may not be json...
                 resp.raise_for_status()
                 data = json.loads(await resp.read())
@@ -200,8 +273,10 @@ async def _scan_tag(
                 updates[f'{tag_prefix}/accels'] = accels
 
             res_prefix = 'ai.backend.resource.min.'
-            for k, v in filter(lambda pair: pair[0].startswith(res_prefix),
-                               labels.items()):
+            for k, v in filter(
+                lambda pair: pair[0].startswith(res_prefix),
+                labels.items(),
+            ):
                 res_key = k[len(res_prefix):]
                 updates[f'{tag_prefix}/resource/{res_key}/min'] = v
             self.all_updates.get().update(updates)
diff --git a/src/ai/backend/manager/container_registry/docker.py b/src/ai/backend/manager/container_registry/docker.py
index 03e08cab1..a34e20032 100644
--- a/src/ai/backend/manager/container_registry/docker.py
+++ b/src/ai/backend/manager/container_registry/docker.py
@@ -17,6 +17,10 @@
 
 class DockerHubRegistry(BaseContainerRegistry):
 
+    # Docker Hub allows unauthenticated users only 100 manifest GET requests per 6 hours.
+    default_rate_limit = (100, 30)
+    manifest_rate_limit = (100, 6 * 3600)
+
     async def fetch_repositories(
         self,
         sess: aiohttp.ClientSession,
@@ -28,7 +32,10 @@ async def fetch_repositories(
         repo_list_url: Optional[yarl.URL]
         repo_list_url = hub_url / f'v2/repositories/{username}/'
         while repo_list_url is not None:
-            async with sess.get(repo_list_url, params=params) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(repo_list_url, params=params) as resp,
+            ):
                 if resp.status == 200:
                     data = await resp.json()
                     for item in data['results']:
@@ -70,7 +77,10 @@ async def fetch_repositories(
             {'n': '30'},
         )
         while catalog_url is not None:
-            async with sess.get(catalog_url, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(catalog_url, **rqst_args) as resp,
+            ):
                 if resp.status == 200:
                     data = json.loads(await resp.read())
                     for item in data['repositories']:
diff --git a/src/ai/backend/manager/container_registry/harbor.py b/src/ai/backend/manager/container_registry/harbor.py
index 553a7ac4a..6486f2cd1 100644
--- a/src/ai/backend/manager/container_registry/harbor.py
+++ b/src/ai/backend/manager/container_registry/harbor.py
@@ -31,7 +31,12 @@ async def fetch_repositories(
         )
         project_ids = []
         while project_list_url is not None:
-            async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(
+                    project_list_url, allow_redirects=False, **rqst_args,
+                ) as resp,
+            ):
                 projects = await resp.json()
                 for item in projects:
                     if item['name'] in registry_projects:
@@ -90,7 +95,10 @@ async def fetch_repositories(
             {'page_size': '30'},
         )
         while repo_list_url is not None:
-            async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp,
+            ):
                 items = await resp.json()
                 if isinstance(items, dict) and (errors := items.get('errors', [])):
                     raise RuntimeError(f"failed to fetch repositories in project {project_name}",
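
For reviewers unfamiliar with aiolimiter: the pattern applied at every call site above is a shared leaky-bucket limiter entered as an async context manager alongside the aiohttp request, so all concurrent scan tasks collectively stay under the registry's request budget instead of provoking HTTP 429 responses. A minimal standalone sketch of that pattern (the URL and the 10-per-30-seconds budget are placeholders, not values from this patch):

import asyncio

import aiohttp
from aiolimiter import AsyncLimiter

# One limiter shared by every task: at most 10 acquisitions per 30 seconds.
limiter = AsyncLimiter(max_rate=10, time_period=30)

async def fetch_tags(sess: aiohttp.ClientSession, image: str) -> list[str]:
    # The parenthesized multi-item `async with` used in the patch needs
    # Python 3.10+; on older interpreters, nest the two statements instead.
    async with (
        limiter,  # blocks here once the 10-per-30s budget is used up
        sess.get(f"https://registry.example.com/v2/{image}/tags/list") as resp,
    ):
        resp.raise_for_status()
        data = await resp.json()
        return data.get("tags") or []

async def main() -> None:
    async with aiohttp.ClientSession() as sess:
        images = ["lablup/python", "lablup/r-base"]
        print(await asyncio.gather(*(fetch_tags(sess, img) for img in images)))

asyncio.run(main())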
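The per-registry budgets are declared as class attributes and read through type(self), so a backend such as DockerHubRegistry only overrides the tuples and the limiters built in BaseContainerRegistry.__init__ pick them up automatically. A trimmed illustration of just that mechanism (the Example* names are invented for the sketch; the tuples mirror the patch's values):

from typing import ClassVar

from aiolimiter import AsyncLimiter

class ExampleBaseRegistry:
    # (max_rate, time_period in seconds); generous defaults for private registries
    default_rate_limit: ClassVar = (200, 30)
    manifest_rate_limit: ClassVar = (200, 30)

    def __init__(self) -> None:
        # type(self) resolves to the most-derived class, so subclass
        # overrides take effect without touching this constructor.
        self.default_rate_limiter = AsyncLimiter(*type(self).default_rate_limit)
        self.manifest_rate_limiter = AsyncLimiter(*type(self).manifest_rate_limit)

class ExampleDockerHubRegistry(ExampleBaseRegistry):
    default_rate_limit = (100, 30)
    manifest_rate_limit = (100, 6 * 3600)  # 100 manifest GETs per 6 hours

r = ExampleDockerHubRegistry()
assert r.manifest_rate_limiter.max_rate == 100
assert r.manifest_rate_limiter.time_period == 6 * 3600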
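A note on the commented-out persistence path: keeping _load_limiter/_save_limiter disabled looks deliberate. Pickling the limiter to a temp file would let a restarted manager resume with its remaining budget rather than a fresh one, but several manager processes could read and write the same file concurrently; the TODO about a per-registry global lock marks exactly that gap. There is also a subtler issue: hash() on strings (and hence on yarl.URL) is salted per process by default, so the f"{hash(self.registry_url):x}" filename prefix changes across restarts unless PYTHONHASHSEED is pinned, which would defeat the save/load round-trip. Until those are resolved, each process simply starts with a full, independent budget.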