Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit d481b19

Browse files
authored
Improve type annotations and prepare for future in-session LB (#150)
* refactor: Update type annotations and prepare the skeleton
* feat: skeleton for future load balancer implementation
1 parent 1b8e919 commit d481b19

File tree

4 files changed

+133
-33
lines changed

4 files changed

+133
-33
lines changed

changes/150.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type annotations for API configurations and prepare the skeleton for future in-session endpoint load balancing implementation

src/ai/backend/client/config.py

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
import enum
12
import os
23
from pathlib import Path
34
import random
45
import re
56
from typing import (
6-
Any,
77
Callable,
88
Iterable,
99
List,
1010
Mapping,
1111
Optional,
1212
Sequence,
1313
Tuple,
14+
TypeVar,
1415
Union,
1516
cast,
1617
)
@@ -28,8 +29,13 @@
2829
'MAX_INFLIGHT_CHUNKS',
2930
]
3031

32+
33+
class Undefined(enum.Enum):
34+
token = object()
35+
36+
3137
_config = None
32-
_undefined = object()
38+
_undefined = Undefined.token
3339

3440
API_VERSION = (6, '20200815')
3541

@@ -47,8 +53,19 @@ def parse_api_version(value: str) -> Tuple[int, str]:
4753
raise ValueError('Could not parse the given API version string', value)
4854

4955

50-
def get_env(key: str, default: Any = _undefined, *,
51-
clean: Callable[[str], Any] = lambda v: v):
56+
T = TypeVar('T')
57+
58+
59+
def default_clean(v: str) -> T:
60+
return cast(T, v)
61+
62+
63+
def get_env(
64+
key: str,
65+
default: Union[str, Undefined] = _undefined,
66+
*,
67+
clean: Callable[[str], T] = default_clean,
68+
) -> T:
5269
"""
5370
Retrieves a configuration value from the environment variables.
5471
The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then
@@ -64,14 +81,14 @@ def get_env(key: str, default: Any = _undefined, *,
6481
:returns: The value processed by the *clean* function.
6582
"""
6683
key = key.upper()
67-
v = os.environ.get('BACKEND_' + key)
68-
if v is None:
69-
v = os.environ.get('SORNA_' + key)
70-
if v is None:
84+
raw = os.environ.get('BACKEND_' + key)
85+
if raw is None:
86+
raw = os.environ.get('SORNA_' + key)
87+
if raw is None:
7188
if default is _undefined:
7289
raise KeyError(key)
73-
v = default
74-
return clean(v)
90+
raw = default
91+
return clean(raw)
7592

7693

7794
def bool_env(v: str) -> bool:
@@ -86,8 +103,8 @@ def bool_env(v: str) -> bool:
86103
def _clean_urls(v: Union[URL, str]) -> List[URL]:
87104
if isinstance(v, URL):
88105
return [v]
106+
urls = []
89107
if isinstance(v, str):
90-
urls = []
91108
for entry in v.split(','):
92109
url = URL(entry)
93110
if not url.is_absolute():
@@ -96,12 +113,10 @@ def _clean_urls(v: Union[URL, str]) -> List[URL]:
96113
return urls
97114

98115

99-
def _clean_tokens(v):
100-
if isinstance(v, str):
101-
if not v:
102-
return tuple()
103-
return tuple(v.split(','))
104-
return tuple(iter(v))
116+
def _clean_tokens(v: str) -> Tuple[str, ...]:
117+
if not v:
118+
return tuple()
119+
return tuple(v.split(','))
105120

106121

107122
class APIConfig:
@@ -141,21 +156,22 @@ class APIConfig:
141156
<ai.backend.client.kernel.Kernel.get_or_create>` calls.
142157
"""
143158

144-
DEFAULTS: Mapping[str, Any] = {
159+
DEFAULTS: Mapping[str, str] = {
145160
'endpoint': 'https://api.backend.ai',
146161
'endpoint_type': 'api',
147162
'version': f'v{API_VERSION[0]}.{API_VERSION[1]}',
148163
'hash_type': 'sha256',
149164
'domain': 'default',
150165
'group': 'default',
151-
'connection_timeout': 10.0,
152-
'read_timeout': None,
166+
'connection_timeout': '10.0',
167+
'read_timeout': '0',
153168
}
154169
"""
155170
The default values for config parameters settable via environment variables
156171
except the access and secret keys.
157172
"""
158173

174+
_endpoints: List[URL]
159175
_group: str
160176
_hash_type: str
161177

@@ -179,35 +195,39 @@ def __init__(
179195
from . import get_user_agent
180196
self._endpoints = (
181197
_clean_urls(endpoint) if endpoint else
182-
get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls))
198+
get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls)
199+
)
183200
random.shuffle(self._endpoints)
184-
self._endpoint_type = endpoint_type if endpoint_type is not None \
185-
else get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type'])
186-
self._domain = domain if domain is not None else get_env('DOMAIN', self.DEFAULTS['domain'])
187-
self._group = group if group is not None else get_env('GROUP', self.DEFAULTS['group'])
188-
self._version = version if version is not None else self.DEFAULTS['version']
201+
self._endpoint_type = endpoint_type if endpoint_type is not None else \
202+
get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type'], clean=str)
203+
self._domain = domain if domain is not None else \
204+
get_env('DOMAIN', self.DEFAULTS['domain'], clean=str)
205+
self._group = group if group is not None else \
206+
get_env('GROUP', self.DEFAULTS['group'], clean=str)
207+
self._version = version if version is not None else \
208+
self.DEFAULTS['version']
189209
self._user_agent = user_agent if user_agent is not None else get_user_agent()
190210
if self._endpoint_type == 'api':
191-
self._access_key = access_key if access_key is not None \
192-
else get_env('ACCESS_KEY', '')
193-
self._secret_key = secret_key if secret_key is not None \
194-
else get_env('SECRET_KEY', '')
211+
self._access_key = access_key if access_key is not None else \
212+
get_env('ACCESS_KEY', '')
213+
self._secret_key = secret_key if secret_key is not None else \
214+
get_env('SECRET_KEY', '')
195215
else:
196216
self._access_key = 'dummy'
197217
self._secret_key = 'dummy'
198218
self._hash_type = hash_type.lower() if hash_type is not None else \
199219
cast(str, self.DEFAULTS['hash_type'])
200220
arg_vfolders = set(vfolder_mounts) if vfolder_mounts else set()
201-
env_vfolders = set(get_env('VFOLDER_MOUNTS', [], clean=_clean_tokens))
221+
env_vfolders = set(get_env('VFOLDER_MOUNTS', '', clean=_clean_tokens))
202222
self._vfolder_mounts = [*(arg_vfolders | env_vfolders)]
203223
# prefer the argument flag and fallback to env if the flag is not set.
204224
self._skip_sslcert_validation = (skip_sslcert_validation
205225
if skip_sslcert_validation else
206226
get_env('SKIP_SSLCERT_VALIDATION', 'no', clean=bool_env))
207227
self._connection_timeout = connection_timeout if connection_timeout else \
208-
get_env('CONNECTION_TIMEOUT', self.DEFAULTS['connection_timeout'])
228+
get_env('CONNECTION_TIMEOUT', self.DEFAULTS['connection_timeout'], clean=float)
209229
self._read_timeout = read_timeout if read_timeout else \
210-
get_env('READ_TIMEOUT', self.DEFAULTS['read_timeout'])
230+
get_env('READ_TIMEOUT', self.DEFAULTS['read_timeout'], clean=float)
211231
self._announcement_handler = announcement_handler
212232

213233
@property
@@ -233,6 +253,9 @@ def rotate_endpoints(self):
233253
item = self._endpoints.pop(0)
234254
self._endpoints.append(item)
235255

256+
def load_balance_endpoints(self):
257+
pass
258+
236259
@property
237260
def endpoint_type(self) -> str:
238261
"""
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
from abc import ABCMeta, abstractmethod
4+
from typing import List, Mapping, Tuple, Type
5+
6+
import attr
7+
from yarl import URL
8+
9+
10+
@attr.s(auto_attribs=True, frozen=True)
11+
class LoadBalancerConfig():
12+
name: str
13+
args: Tuple[str, ...]
14+
15+
16+
class LoadBalancer(metaclass=ABCMeta):
17+
18+
@staticmethod
19+
def load(config: LoadBalancerConfig) -> LoadBalancer:
20+
cls = _cls_map[config.name]
21+
return cls(*config.args)
22+
23+
@staticmethod
24+
def clean_config(config: str) -> LoadBalancerConfig:
25+
name, _, raw_args = config.partition(':')
26+
args = raw_args.split(',')
27+
return LoadBalancerConfig(name, tuple(args))
28+
29+
@abstractmethod
30+
def rotate(self, endpoints: List[URL]) -> None:
31+
raise NotImplementedError
32+
33+
34+
class SimpleRRLoadBalancer(LoadBalancer):
35+
"""
36+
Rotates the endpoints upon every request.
37+
"""
38+
39+
def rotate(self, endpoints: List[URL]) -> None:
40+
if len(endpoints) == 1:
41+
return
42+
item = endpoints.pop(0)
43+
endpoints.append(item)
44+
45+
46+
class PeriodicRRLoadBalancer(LoadBalancer):
47+
"""
48+
Rotates the endpoints upon the specified interval.
49+
"""
50+
51+
def rotate(self, endpoints: List[URL]) -> None:
52+
pass
53+
54+
55+
class LowestLatencyLoadBalancer(LoadBalancer):
56+
"""
57+
Changes the endpoints with the lowest average latency for the last N requests.
58+
"""
59+
60+
def rotate(self, endpoints: List[URL]) -> None:
61+
pass
62+
63+
# TODO: we need to collect and allow access to the latency statistics.
64+
65+
66+
_cls_map: Mapping[str, Type[LoadBalancer]] = {
67+
'simple_rr': SimpleRRLoadBalancer,
68+
'periodic_rr': PeriodicRRLoadBalancer,
69+
'lowest_latency': LowestLatencyLoadBalancer,
70+
}

src/ai/backend/client/request.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,8 @@ async def __aenter__(self) -> Response:
579579
'\u279c {!r}'.format(e)
580580
await raw_resp.__aexit__(*sys.exc_info())
581581
raise BackendClientError(msg) from e
582+
finally:
583+
self.session.config.load_balance_endpoints()
582584

583585
async def __aexit__(self, *exc_info) -> Optional[bool]:
584586
assert self._rqst_ctx is not None
@@ -716,6 +718,8 @@ async def __aenter__(self) -> WebSocketResponse:
716718
raise BackendClientError(msg) from e
717719
else:
718720
break
721+
finally:
722+
self.session.config.load_balance_endpoints()
719723

720724
wrapped_ws = self.response_cls(self.session, cast(aiohttp.ClientResponse, raw_ws))
721725
if self.on_enter is not None:
@@ -871,6 +875,8 @@ async def __aenter__(self) -> SSEResponse:
871875
msg = 'API endpoint response error.\n' \
872876
'\u279c {!r}'.format(e)
873877
raise BackendClientError(msg) from e
878+
finally:
879+
self.session.config.load_balance_endpoints()
874880

875881
async def __aexit__(self, *args) -> Optional[bool]:
876882
assert self._rqst_ctx is not None

0 commit comments

Comments
 (0)