1 change: 1 addition & 0 deletions setup.cfg
@@ -40,6 +40,7 @@ install_requires =
     aiohttp~=3.8.0
     aiohttp_cors~=0.7
     aiohttp_sse~=2.0
+    aiolimiter~=1.0
     aiomonitor~=0.4.5
     aioredis[hiredis]~=2.0
     aiotools~=1.4.0
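The new dependency, aiolimiter, provides AsyncLimiter: a leaky-bucket rate limiter used as an async context manager, which suspends the caller once max_rate acquisitions have been made within the trailing time_period seconds. A minimal usage sketch (names illustrative):

    import asyncio
    from aiolimiter import AsyncLimiter

    async def main() -> None:
        limiter = AsyncLimiter(max_rate=4, time_period=1.0)  # 4 acquisitions per second

        async def hit(i: int) -> None:
            async with limiter:  # suspends here once the bucket is drained
                print(f"request {i}")

        await asyncio.gather(*(hit(i) for i in range(10)))

    asyncio.run(main())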
91 changes: 83 additions & 8 deletions src/ai/backend/manager/container_registry/base.py
@@ -4,9 +4,22 @@
 from contextvars import ContextVar
 import logging
 import json
-from typing import Any, AsyncIterator, Dict, Mapping, Optional, TYPE_CHECKING, cast
+from pathlib import Path
+import pickle
+import tempfile
+from typing import (
+    Any,
+    AsyncIterator,
+    ClassVar,
+    Dict,
+    Mapping,
+    Optional,
+    TYPE_CHECKING,
+    cast,
+)

 import aiohttp
+from aiolimiter import AsyncLimiter
 import aiotools
 import yarl

@@ -27,9 +40,42 @@

 log = BraceStyleAdapter(logging.getLogger(__name__))

+_limiters: dict[str, AsyncLimiter] = {}
+
+
+def _load_limiter(name: str, default_config: tuple[float, float]) -> AsyncLimiter:
+    if (o := _limiters.get(name)) is not None:
+        return o
+    path = Path(tempfile.gettempdir(), f"bai.container_registry.asynclimiter.{name}")
+    try:
+        with open(path, "rb") as f:
+            o = cast(AsyncLimiter, pickle.load(f))
+            if o.max_rate != default_config[0] or o.time_period != default_config[1]:
+                o = AsyncLimiter(max_rate=default_config[0], time_period=default_config[1])
+                print(f"Recreated async limiter from pickled file: {name}", flush=True)
+            else:
+                print(f"Loaded async limiter from pickled file: {name}", flush=True)
+    except (OSError, pickle.PickleError):
+        o = AsyncLimiter(max_rate=default_config[0], time_period=default_config[1])
+    _limiters[name] = o
+    return o
+
+
+def _save_limiter(name: str, config: AsyncLimiter) -> None:
+    path = Path(tempfile.gettempdir(), f"bai.container_registry.asynclimiter.{name}")
+    try:
+        with open(path, "wb") as f:
+            pickle.dump(config, f)
+        print(f"Saved async limiter to pickled file: {name}", flush=True)
+    except OSError:
+        log.debug("Failed to store async limiter ({}) status in {}", name, path)
+

 class BaseContainerRegistry(metaclass=ABCMeta):

+    default_rate_limit: ClassVar = (200, 30)
+    manifest_rate_limit: ClassVar = (200, 30)
+
     etcd: AsyncEtcd
     registry_name: str
     registry_info: Mapping[str, Any]
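The two helpers above cache limiters per process in _limiters and persist their state across restarts by pickling them into the temp directory; any load failure silently falls back to a fresh limiter built from the default budget. An illustrative round-trip, assuming a writable temp dir:

    limiter = _load_limiter("example.default", (200, 30))  # fresh limiter on first use
    _save_limiter("example.default", limiter)              # best-effort pickle into $TMPDIR
    assert _load_limiter("example.default", (200, 30)) is limiter  # now served from the cache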
@@ -65,6 +111,20 @@ def __init__(
         self.sema = ContextVar('sema')
         self.reporter = ContextVar('reporter', default=None)
         self.all_updates = ContextVar('all_updates')
+        self._limiter_prefix = f"{hash(self.registry_url):x}"
+        # TODO: Use a per-registry global lock to store/load rate limiter states.
+        self.default_rate_limiter = AsyncLimiter(*type(self).default_rate_limit)
+        self.manifest_rate_limiter = AsyncLimiter(*type(self).manifest_rate_limit)
+        # self.default_rate_limiter = _load_limiter(
+        #     f"{self._limiter_prefix}.default",
+        #     type(self).default_rate_limit,
+        # )
+        # self.manifest_rate_limiter = _load_limiter(
+        #     f"{self._limiter_prefix}.manifest",
+        #     type(self).manifest_rate_limit,
+        # )
+        # atexit.register(_save_limiter, f"{self._limiter_prefix}.default", self.default_rate_limiter)
+        # atexit.register(_save_limiter, f"{self._limiter_prefix}.manifest", self.manifest_rate_limiter)

     async def rescan_single_registry(
         self,
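Since the budgets are plain class attributes, a registry backend tunes its limits by overriding the two tuples, and each instance builds its own AsyncLimiter pair from them in __init__. (If the commented-out persistence path is ever enabled, the module will also need an import atexit.) A hypothetical subclass:

    class ThrottledRegistry(BaseContainerRegistry):
        # Hypothetical backend allowing 50 requests/min overall
        # and 10 manifest requests/min.
        default_rate_limit = (50, 60)
        manifest_rate_limit = (10, 60)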
@@ -118,7 +178,10 @@ async def _scan_image(
             {'n': '10'},
         )
         while tag_list_url is not None:
-            async with sess.get(tag_list_url, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(tag_list_url, **rqst_args) as resp,
+            ):
                 data = json.loads(await resp.read())
                 if 'tags' in data:
                     # sometimes there are dangling image names in the hub.
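The parenthesized multi-item async with groups the limiter and the HTTP request into one statement; it is equivalent to the nested form, which pre-3.10 parsers would require:

    # Equivalent nesting of the block above:
    async with self.default_rate_limiter:
        async with sess.get(tag_list_url, **rqst_args) as resp:
            data = json.loads(await resp.read())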
@@ -150,8 +213,13 @@ async def _scan_tag(
         skip_reason = None
         try:
             async with self.sema.get():
-                async with sess.get(self.registry_url / f'v2/{image}/manifests/{tag}',
-                                    **rqst_args) as resp:
+                async with (
+                    self.manifest_rate_limiter,
+                    sess.get(
+                        self.registry_url / f'v2/{image}/manifests/{tag}',
+                        **rqst_args,
+                    ) as resp,
+                ):
                     if resp.status == 404:
                         # ignore missing tags
                         # (may occur after deleting an image from the docker hub)
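Because the semaphore is entered before the limiter, a scan task first claims a concurrency slot and only then waits for rate capacity while holding it. Condensed, the manifest fetch has this shape (manifest_url is shorthand for the URL built above):

    async with self.sema.get():                 # bound concurrent tag scans
        async with self.manifest_rate_limiter:  # then bound the manifest request rate
            async with sess.get(manifest_url, **rqst_args) as resp:
                ...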
@@ -163,8 +231,13 @@
                     size_bytes = (sum(layer['size'] for layer in data['layers']) +
                                   data['config']['size'])

-                async with sess.get(self.registry_url / f'v2/{image}/blobs/{config_digest}',
-                                    **rqst_args) as resp:
+                async with (
+                    self.default_rate_limiter,
+                    sess.get(
+                        self.registry_url / f'v2/{image}/blobs/{config_digest}',
+                        **rqst_args,
+                    ) as resp,
+                ):
                     # content-type may not be json...
                     resp.raise_for_status()
                     data = json.loads(await resp.read())
@@ -200,8 +273,10 @@ async def _scan_tag(
                     updates[f'{tag_prefix}/accels'] = accels

                     res_prefix = 'ai.backend.resource.min.'
-                    for k, v in filter(lambda pair: pair[0].startswith(res_prefix),
-                                       labels.items()):
+                    for k, v in filter(
+                        lambda pair: pair[0].startswith(res_prefix),
+                        labels.items(),
+                    ):
                         res_key = k[len(res_prefix):]
                         updates[f'{tag_prefix}/resource/{res_key}/min'] = v
         self.all_updates.get().update(updates)
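The reflowed filter call keeps only image labels under the ai.backend.resource.min. prefix and maps them to per-tag minimum resource slots; for example (label values illustrative):

    labels = {
        "ai.backend.resource.min.cpu": "1",
        "ai.backend.resource.min.mem": "256m",
        "ai.backend.service-ports": "jupyter:http:8080",  # filtered out
    }
    # yields: updates[f"{tag_prefix}/resource/cpu/min"] = "1"
    #         updates[f"{tag_prefix}/resource/mem/min"] = "256m"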
14 changes: 12 additions & 2 deletions src/ai/backend/manager/container_registry/docker.py
@@ -17,6 +17,10 @@

 class DockerHubRegistry(BaseContainerRegistry):

+    # Anonymous Docker Hub users are limited to 100 manifest GET requests per 6 hours.
+    default_rate_limit = (100, 30)
+    manifest_rate_limit = (100, 6 * 3600)
+
     async def fetch_repositories(
         self,
         sess: aiohttp.ClientSession,
@@ -28,7 +32,10 @@ async def fetch_repositories(
         repo_list_url: Optional[yarl.URL]
         repo_list_url = hub_url / f'v2/repositories/{username}/'
         while repo_list_url is not None:
-            async with sess.get(repo_list_url, params=params) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(repo_list_url, params=params) as resp,
+            ):
                 if resp.status == 200:
                     data = await resp.json()
                     for item in data['results']:
@@ -70,7 +77,10 @@ async def fetch_repositories(
             {'n': '30'},
         )
         while catalog_url is not None:
-            async with sess.get(catalog_url, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(catalog_url, **rqst_args) as resp,
+            ):
                 if resp.status == 200:
                     data = json.loads(await resp.read())
                     for item in data['repositories']:
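The manifest budget mirrors Docker Hub's anonymous limit of 100 pulls per 6 hours; spread evenly, that allows one manifest request roughly every 216 seconds:

    max_rate, time_period = DockerHubRegistry.manifest_rate_limit
    print(time_period / max_rate)  # 216.0 -> sustained seconds per manifest request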
12 changes: 10 additions & 2 deletions src/ai/backend/manager/container_registry/harbor.py
@@ -31,7 +31,12 @@ async def fetch_repositories(
         )
         project_ids = []
         while project_list_url is not None:
-            async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(
+                    project_list_url, allow_redirects=False, **rqst_args,
+                ) as resp,
+            ):
                 projects = await resp.json()
                 for item in projects:
                     if item['name'] in registry_projects:
@@ -90,7 +95,10 @@ async def fetch_repositories(
             {'page_size': '30'},
         )
         while repo_list_url is not None:
-            async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp:
+            async with (
+                self.default_rate_limiter,
+                sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp,
+            ):
                 items = await resp.json()
                 if isinstance(items, dict) and (errors := items.get('errors', [])):
                     raise RuntimeError(f"failed to fetch repositories in project {project_name}",
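With these changes, all three backends share the same loop shape: paginate through the registry API, taking a limiter token before each page request. Condensed (next_page_url and process are illustrative helpers, not part of the codebase):

    url: Optional[yarl.URL] = first_page_url
    while url is not None:
        async with self.default_rate_limiter, sess.get(url, **rqst_args) as resp:
            resp.raise_for_status()
            data = await resp.json()
            process(data)
            url = next_page_url(resp)  # e.g. from the Link header; None on the last page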