Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit 07346d6

Browse files
committed
refactor: Clean up BaseFunction inheritance using contextvars
* Previously, to bind the current session with API function classes, we generated new type objects at runtime. - This has confused IDEs and type checkers. - Now type checkers can statically deduce the types for individual API function classes. - TODO: many many type errors are there still... * Now we use contextvars (ai.backend.client.session.api_session) to keep the reference to the current session. - There are no public Session/AsyncSession API changes! - Only the API function classes need to be rewritten. - For synchronous Session, we pass the context to the separate worker thread using copy_context() whenever calling API functions, which is a light-weight operation. * Remove redundant src/ai/backend/client/etcd.py which had been already copied to src/ai/backend/client/func/etcd.py BREAKING-CHANGE: Dropped Python 3.6 support. Now it requires Python 3.7 or higher.
1 parent 2edb087 commit 07346d6

26 files changed

+584
-999
lines changed

.travis.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ stages:
1010

1111
# build matrix for test stage
1212
python:
13-
- "3.6"
1413
- "3.7"
1514
- "3.8"
1615
os:
@@ -53,7 +52,7 @@ jobs:
5352
fast_finish: true
5453
# exclude the duplicate default test stage
5554
exclude:
56-
- python: "3.6"
55+
- python: "3.7"
5756

5857
notifications:
5958
webhooks:

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@ norecursedirs = venv virtualenv .git
1313
timeout = 5
1414
markers =
1515
integration: Test cases that require real manager (and agents) to be running on http://localhost:8081.
16+
17+
[mypy]
18+
ignore_missing_imports = true
19+
namespace_packages = true

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def read_src_version():
8282
'Intended Audience :: Developers',
8383
'Programming Language :: Python',
8484
'Programming Language :: Python :: 3',
85-
'Programming Language :: Python :: 3.6',
8685
'Programming Language :: Python :: 3.7',
8786
'Programming Language :: Python :: 3.8',
8887
'Operating System :: POSIX',
@@ -94,7 +93,7 @@ def read_src_version():
9493
],
9594
package_dir={'': 'src'},
9695
packages=find_namespace_packages(where='src', include='ai.backend.*'),
97-
python_requires='>=3.6',
96+
python_requires='>=3.7',
9897
setup_requires=setup_requires,
9998
install_requires=install_requires,
10099
extras_require={

src/ai/backend/client/etcd.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

src/ai/backend/client/func/admin.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,31 @@
11
from typing import Any, Mapping, Optional
22

3-
from .base import api_function
3+
from .base import api_function, BaseFunction
44
from ..request import Request
5+
from ..session import api_session
56

67
__all__ = (
78
'Admin',
89
)
910

1011

11-
class Admin:
12-
'''
12+
class Admin(BaseFunction):
13+
"""
1314
Provides the function interface for making admin GrapQL queries.
1415
1516
.. note::
1617
1718
Depending on the privilege of your API access key, you may or may not
1819
have access to querying/mutating server-side resources of other
1920
users.
20-
'''
21-
22-
session = None
23-
'''The client session instance that this function class is bound to.'''
21+
"""
2422

2523
@api_function
2624
@classmethod
2725
async def query(cls, query: str,
2826
variables: Optional[Mapping[str, Any]] = None,
2927
) -> Any:
30-
'''
28+
"""
3129
Sends the GraphQL query and returns the response.
3230
3331
:param query: The GraphQL query string.
@@ -36,12 +34,12 @@ async def query(cls, query: str,
3634
in the query.
3735
3836
:returns: The object parsed from the response JSON string.
39-
'''
37+
"""
4038
gql_query = {
4139
'query': query,
4240
'variables': variables if variables else {},
4341
}
44-
rqst = Request(cls.session, 'POST', '/admin/graphql')
42+
rqst = Request(api_session.get(), 'POST', '/admin/graphql')
4543
rqst.set_json(gql_query)
4644
async with rqst.fetch() as resp:
4745
return await resp.json()

src/ai/backend/client/func/agent.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import textwrap
22
from typing import Iterable, Sequence
33

4-
from .base import api_function
4+
from .base import api_function, BaseFunction
55
from ..request import Request
6+
from ..session import api_session
67

78
__all__ = (
89
'Agent',
910
'AgentWatcher',
1011
)
1112

1213

13-
class Agent:
14-
'''
14+
class Agent(BaseFunction):
15+
"""
1516
Provides a shortcut of :func:`Admin.query()
1617
<ai.backend.client.admin.Admin.query>` that fetches various agent
1718
information.
@@ -20,10 +21,7 @@ class Agent:
2021
2122
All methods in this function class require your API access key to
2223
have the *admin* privilege.
23-
'''
24-
25-
session = None
26-
'''The client session instance that this function class is bound to.'''
24+
"""
2725

2826
@api_function
2927
@classmethod
@@ -32,7 +30,7 @@ async def list_with_limit(cls,
3230
offset,
3331
status: str = 'ALIVE',
3432
fields: Iterable[str] = None) -> Sequence[dict]:
35-
'''
33+
"""
3634
Fetches the list of agents with the given status with limit and offset for
3735
pagination.
3836
@@ -42,7 +40,7 @@ async def list_with_limit(cls,
4240
status (one of ``'ALIVE'``, ``'TERMINATED'``, ``'LOST'``,
4341
etc.)
4442
:param fields: Additional per-agent query fields to fetch.
45-
'''
43+
"""
4644
if fields is None:
4745
fields = (
4846
'id',
@@ -65,7 +63,7 @@ async def list_with_limit(cls,
6563
'offset': offset,
6664
'status': status,
6765
}
68-
rqst = Request(cls.session, 'POST', '/admin/graphql')
66+
rqst = Request(api_session.get(), 'POST', '/admin/graphql')
6967
rqst.set_json({
7068
'query': q,
7169
'variables': variables,
@@ -81,14 +79,14 @@ async def detail(cls, agent_id: str, fields: Iterable[str] = None) -> Sequence[d
8179
fields = ('id', 'status', 'addr', 'region', 'first_contact',
8280
'cpu_cur_pct', 'mem_cur_bytes',
8381
'available_slots', 'occupied_slots')
84-
query = textwrap.dedent('''\
82+
query = textwrap.dedent("""\
8583
query($agent_id: String!) {
8684
agent(agent_id: $agent_id) {$fields}
8785
}
88-
''')
86+
""")
8987
query = query.replace('$fields', ' '.join(fields))
9088
variables = {'agent_id': agent_id}
91-
rqst = Request(cls.session, 'POST', '/admin/graphql')
89+
rqst = Request(api_session.get(), 'POST', '/admin/graphql')
9290
rqst.set_json({
9391
'query': query,
9492
'variables': variables,
@@ -98,27 +96,24 @@ async def detail(cls, agent_id: str, fields: Iterable[str] = None) -> Sequence[d
9896
return data['agent']
9997

10098

101-
class AgentWatcher:
102-
'''
99+
class AgentWatcher(BaseFunction):
100+
"""
103101
Provides a shortcut of :func:`Admin.query()
104102
<ai.backend.client.admin.Admin.query>` that manipulate agent status.
105103
106104
.. note::
107105
108106
All methods in this function class require you to
109107
have the *superadmin* privilege.
110-
'''
111-
112-
session = None
113-
'''The client session instance that this function class is bound to.'''
108+
"""
114109

115110
@api_function
116111
@classmethod
117112
async def get_status(cls, agent_id: str) -> dict:
118-
'''
113+
"""
119114
Get agent and watcher status.
120-
'''
121-
rqst = Request(cls.session, 'GET', '/resource/watcher')
115+
"""
116+
rqst = Request(api_session.get(), 'GET', '/resource/watcher')
122117
rqst.set_json({'agent_id': agent_id})
123118
async with rqst.fetch() as resp:
124119
data = await resp.json()
@@ -130,10 +125,10 @@ async def get_status(cls, agent_id: str) -> dict:
130125
@api_function
131126
@classmethod
132127
async def agent_start(cls, agent_id: str) -> dict:
133-
'''
128+
"""
134129
Start agent.
135-
'''
136-
rqst = Request(cls.session, 'POST', '/resource/watcher/agent/start')
130+
"""
131+
rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/start')
137132
rqst.set_json({'agent_id': agent_id})
138133
async with rqst.fetch() as resp:
139134
data = await resp.json()
@@ -145,10 +140,10 @@ async def agent_start(cls, agent_id: str) -> dict:
145140
@api_function
146141
@classmethod
147142
async def agent_stop(cls, agent_id: str) -> dict:
148-
'''
143+
"""
149144
Stop agent.
150-
'''
151-
rqst = Request(cls.session, 'POST', '/resource/watcher/agent/stop')
145+
"""
146+
rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/stop')
152147
rqst.set_json({'agent_id': agent_id})
153148
async with rqst.fetch() as resp:
154149
data = await resp.json()
@@ -160,10 +155,10 @@ async def agent_stop(cls, agent_id: str) -> dict:
160155
@api_function
161156
@classmethod
162157
async def agent_restart(cls, agent_id: str) -> dict:
163-
'''
158+
"""
164159
Restart agent.
165-
'''
166-
rqst = Request(cls.session, 'POST', '/resource/watcher/agent/restart')
160+
"""
161+
rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/restart')
167162
rqst.set_json({'agent_id': agent_id})
168163
async with rqst.fetch() as resp:
169164
data = await resp.json()

src/ai/backend/client/func/auth.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
1-
from .base import api_function
1+
from .base import api_function, BaseFunction
22
from ..request import Request
3+
from ..session import api_session
34

45
__all__ = (
56
'Auth',
67
)
78

89

9-
class Auth:
10-
'''
10+
class Auth(BaseFunction):
11+
"""
1112
Provides the function interface for login session management and authorization.
12-
'''
13+
"""
1314

1415
@api_function
1516
@classmethod
1617
async def login(cls, user_id: str, password: str) -> dict:
17-
'''
18+
"""
1819
Log-in into the endpoint with the given user ID and password.
1920
It creates a server-side web session and return
2021
a dictionary with ``"authenticated"`` boolean field and
2122
JSON-encoded raw cookie data.
22-
'''
23-
rqst = Request(cls.session, 'POST', '/server/login')
23+
"""
24+
rqst = Request(api_session.get(), 'POST', '/server/login')
2425
rqst.set_json({
2526
'username': user_id,
2627
'password': password,
@@ -36,11 +37,11 @@ async def login(cls, user_id: str, password: str) -> dict:
3637
@api_function
3738
@classmethod
3839
async def logout(cls) -> None:
39-
'''
40+
"""
4041
Log-out from the endpoint.
4142
It clears the server-side web session.
42-
'''
43-
rqst = Request(cls.session, 'POST', '/server/logout')
43+
"""
44+
rqst = Request(api_session.get(), 'POST', '/server/logout')
4445
async with rqst.fetch() as resp:
4546
resp.raw_response.raise_for_status()
4647

@@ -50,7 +51,7 @@ async def update_password(cls, old_password: str, new_password: str, new_passwor
5051
"""
5152
Update user's password. This API works only for account owner.
5253
"""
53-
rqst = Request(cls.session, 'POST', '/auth/update-password')
54+
rqst = Request(api_session.get(), 'POST', '/auth/update-password')
5455
rqst.set_json({
5556
'old_password': old_password,
5657
'new_password': new_password,

0 commit comments

Comments
 (0)