1
+ import enum
1
2
import os
2
3
from pathlib import Path
3
4
import random
4
5
import re
5
6
from typing import (
6
- Any ,
7
7
Callable ,
8
8
Iterable ,
9
9
List ,
10
10
Mapping ,
11
11
Optional ,
12
12
Sequence ,
13
13
Tuple ,
14
+ TypeVar ,
14
15
Union ,
15
16
cast ,
16
17
)
28
29
'MAX_INFLIGHT_CHUNKS' ,
29
30
]
30
31
32
+
33
+ class Undefined (enum .Enum ):
34
+ token = object ()
35
+
36
+
31
37
_config = None
32
- _undefined = object ()
38
+ _undefined = Undefined . token
33
39
34
40
API_VERSION = (6 , '20200815' )
35
41
@@ -47,8 +53,19 @@ def parse_api_version(value: str) -> Tuple[int, str]:
47
53
raise ValueError ('Could not parse the given API version string' , value )
48
54
49
55
50
- def get_env (key : str , default : Any = _undefined , * ,
51
- clean : Callable [[str ], Any ] = lambda v : v ):
56
+ T = TypeVar ('T' )
57
+
58
+
59
+ def default_clean (v : str ) -> T :
60
+ return cast (T , v )
61
+
62
+
63
+ def get_env (
64
+ key : str ,
65
+ default : Union [str , Undefined ] = _undefined ,
66
+ * ,
67
+ clean : Callable [[str ], T ] = default_clean ,
68
+ ) -> T :
52
69
"""
53
70
Retrieves a configuration value from the environment variables.
54
71
The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then
@@ -64,14 +81,14 @@ def get_env(key: str, default: Any = _undefined, *,
64
81
:returns: The value processed by the *clean* function.
65
82
"""
66
83
key = key .upper ()
67
- v = os .environ .get ('BACKEND_' + key )
68
- if v is None :
69
- v = os .environ .get ('SORNA_' + key )
70
- if v is None :
84
+ raw = os .environ .get ('BACKEND_' + key )
85
+ if raw is None :
86
+ raw = os .environ .get ('SORNA_' + key )
87
+ if raw is None :
71
88
if default is _undefined :
72
89
raise KeyError (key )
73
- v = default
74
- return clean (v )
90
+ raw = default
91
+ return clean (raw )
75
92
76
93
77
94
def bool_env (v : str ) -> bool :
@@ -86,8 +103,8 @@ def bool_env(v: str) -> bool:
86
103
def _clean_urls (v : Union [URL , str ]) -> List [URL ]:
87
104
if isinstance (v , URL ):
88
105
return [v ]
106
+ urls = []
89
107
if isinstance (v , str ):
90
- urls = []
91
108
for entry in v .split (',' ):
92
109
url = URL (entry )
93
110
if not url .is_absolute ():
@@ -96,12 +113,10 @@ def _clean_urls(v: Union[URL, str]) -> List[URL]:
96
113
return urls
97
114
98
115
99
- def _clean_tokens (v ):
100
- if isinstance (v , str ):
101
- if not v :
102
- return tuple ()
103
- return tuple (v .split (',' ))
104
- return tuple (iter (v ))
116
+ def _clean_tokens (v : str ) -> Tuple [str , ...]:
117
+ if not v :
118
+ return tuple ()
119
+ return tuple (v .split (',' ))
105
120
106
121
107
122
class APIConfig :
@@ -141,21 +156,22 @@ class APIConfig:
141
156
<ai.backend.client.kernel.Kernel.get_or_create>` calls.
142
157
"""
143
158
144
- DEFAULTS : Mapping [str , Any ] = {
159
+ DEFAULTS : Mapping [str , str ] = {
145
160
'endpoint' : 'https://api.backend.ai' ,
146
161
'endpoint_type' : 'api' ,
147
162
'version' : f'v{ API_VERSION [0 ]} .{ API_VERSION [1 ]} ' ,
148
163
'hash_type' : 'sha256' ,
149
164
'domain' : 'default' ,
150
165
'group' : 'default' ,
151
- 'connection_timeout' : 10.0 ,
152
- 'read_timeout' : None ,
166
+ 'connection_timeout' : ' 10.0' ,
167
+ 'read_timeout' : '0' ,
153
168
}
154
169
"""
155
170
The default values for config parameterse settable via environment variables
156
171
xcept the access and secret keys.
157
172
"""
158
173
174
+ _endpoints : List [URL ]
159
175
_group : str
160
176
_hash_type : str
161
177
@@ -179,35 +195,39 @@ def __init__(
179
195
from . import get_user_agent
180
196
self ._endpoints = (
181
197
_clean_urls (endpoint ) if endpoint else
182
- get_env ('ENDPOINT' , self .DEFAULTS ['endpoint' ], clean = _clean_urls ))
198
+ get_env ('ENDPOINT' , self .DEFAULTS ['endpoint' ], clean = _clean_urls )
199
+ )
183
200
random .shuffle (self ._endpoints )
184
- self ._endpoint_type = endpoint_type if endpoint_type is not None \
185
- else get_env ('ENDPOINT_TYPE' , self .DEFAULTS ['endpoint_type' ])
186
- self ._domain = domain if domain is not None else get_env ('DOMAIN' , self .DEFAULTS ['domain' ])
187
- self ._group = group if group is not None else get_env ('GROUP' , self .DEFAULTS ['group' ])
188
- self ._version = version if version is not None else self .DEFAULTS ['version' ]
201
+ self ._endpoint_type = endpoint_type if endpoint_type is not None else \
202
+ get_env ('ENDPOINT_TYPE' , self .DEFAULTS ['endpoint_type' ], clean = str )
203
+ self ._domain = domain if domain is not None else \
204
+ get_env ('DOMAIN' , self .DEFAULTS ['domain' ], clean = str )
205
+ self ._group = group if group is not None else \
206
+ get_env ('GROUP' , self .DEFAULTS ['group' ], clean = str )
207
+ self ._version = version if version is not None else \
208
+ self .DEFAULTS ['version' ]
189
209
self ._user_agent = user_agent if user_agent is not None else get_user_agent ()
190
210
if self ._endpoint_type == 'api' :
191
- self ._access_key = access_key if access_key is not None \
192
- else get_env ('ACCESS_KEY' , '' )
193
- self ._secret_key = secret_key if secret_key is not None \
194
- else get_env ('SECRET_KEY' , '' )
211
+ self ._access_key = access_key if access_key is not None else \
212
+ get_env ('ACCESS_KEY' , '' )
213
+ self ._secret_key = secret_key if secret_key is not None else \
214
+ get_env ('SECRET_KEY' , '' )
195
215
else :
196
216
self ._access_key = 'dummy'
197
217
self ._secret_key = 'dummy'
198
218
self ._hash_type = hash_type .lower () if hash_type is not None else \
199
219
cast (str , self .DEFAULTS ['hash_type' ])
200
220
arg_vfolders = set (vfolder_mounts ) if vfolder_mounts else set ()
201
- env_vfolders = set (get_env ('VFOLDER_MOUNTS' , [] , clean = _clean_tokens ))
221
+ env_vfolders = set (get_env ('VFOLDER_MOUNTS' , '' , clean = _clean_tokens ))
202
222
self ._vfolder_mounts = [* (arg_vfolders | env_vfolders )]
203
223
# prefer the argument flag and fallback to env if the flag is not set.
204
224
self ._skip_sslcert_validation = (skip_sslcert_validation
205
225
if skip_sslcert_validation else
206
226
get_env ('SKIP_SSLCERT_VALIDATION' , 'no' , clean = bool_env ))
207
227
self ._connection_timeout = connection_timeout if connection_timeout else \
208
- get_env ('CONNECTION_TIMEOUT' , self .DEFAULTS ['connection_timeout' ])
228
+ get_env ('CONNECTION_TIMEOUT' , self .DEFAULTS ['connection_timeout' ], clean = float )
209
229
self ._read_timeout = read_timeout if read_timeout else \
210
- get_env ('READ_TIMEOUT' , self .DEFAULTS ['read_timeout' ])
230
+ get_env ('READ_TIMEOUT' , self .DEFAULTS ['read_timeout' ], clean = float )
211
231
self ._announcement_handler = announcement_handler
212
232
213
233
@property
@@ -233,6 +253,9 @@ def rotate_endpoints(self):
233
253
item = self ._endpoints .pop (0 )
234
254
self ._endpoints .append (item )
235
255
256
+ def load_balance_endpoints (self ):
257
+ pass
258
+
236
259
@property
237
260
def endpoint_type (self ) -> str :
238
261
"""
0 commit comments