Skip to content

Commit 99c0f46

Browse files
authored
add update kube token task (#301)
1 parent 23c22b1 commit 99c0f46

File tree

3 files changed

+126
-42
lines changed

3 files changed

+126
-42
lines changed

platform_container_runtime/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class KubeConfig:
2727
client_key_path: Optional[str] = None
2828
token: Optional[str] = None
2929
token_path: Optional[str] = None
30+
token_update_interval_s: int = 300
3031
conn_force_close: bool = False
3132
conn_timeout_s: int = 300
3233
read_timeout_s: int = 100

platform_container_runtime/kube_client.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import logging
23
import ssl
34
from collections.abc import Sequence
5+
from contextlib import suppress
46
from dataclasses import dataclass, field
57
from pathlib import Path
68
from types import TracebackType
@@ -13,7 +15,7 @@
1315
logger = logging.getLogger(__name__)
1416

1517

16-
class KubeClientUnautorized(Exception):
18+
class KubeClientUnauthorized(Exception):
1719
pass
1820

1921

@@ -60,6 +62,7 @@ def __init__(
6062
self._token = config.token
6163
self._trace_configs = trace_configs
6264
self._client: Optional[aiohttp.ClientSession] = None
65+
self._token_updater_task: Optional[asyncio.Task[None]] = None
6366

6467
def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
6568
if self._config.url.scheme != "https":
@@ -76,7 +79,7 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
7679
return ssl_context
7780

7881
async def __aenter__(self) -> "KubeClient":
79-
self._client = await self._create_http_client()
82+
await self._init()
8083
return self
8184

8285
async def __aexit__(
@@ -87,61 +90,64 @@ async def __aexit__(
8790
) -> None:
8891
await self.aclose()
8992

90-
async def _create_http_client(self) -> aiohttp.ClientSession:
93+
async def _init(self) -> None:
9194
connector = aiohttp.TCPConnector(
9295
limit=self._config.conn_pool_size,
9396
force_close=self._config.conn_force_close,
9497
ssl=self._create_ssl_context(),
9598
)
96-
if self._config.auth_type == KubeClientAuthType.TOKEN:
97-
token = self._token
98-
if not token:
99-
assert self._config.token_path is not None
100-
token = Path(self._config.token_path).read_text()
101-
headers = {"Authorization": "Bearer " + token}
102-
else:
103-
headers = {}
99+
if self._config.token_path:
100+
self._token = Path(self._config.token_path).read_text()
101+
self._token_updater_task = asyncio.create_task(self._start_token_updater())
104102
timeout = aiohttp.ClientTimeout(
105103
connect=self._config.conn_timeout_s, total=self._config.read_timeout_s
106104
)
107-
return aiohttp.ClientSession(
105+
self._client = aiohttp.ClientSession(
108106
connector=connector,
109107
timeout=timeout,
110-
headers=headers,
111108
trace_configs=self._trace_configs,
112109
)
113110

114-
async def _reload_http_client(self) -> None:
115-
if self._client:
116-
await self._client.close()
117-
self._token = None
118-
self._client = await self._create_http_client()
119-
120-
async def init_if_needed(self) -> None:
121-
if not self._client or self._client.closed:
122-
await self._reload_http_client()
111+
async def _start_token_updater(self) -> None:
112+
if not self._config.token_path:
113+
return
114+
while True:
115+
try:
116+
token = Path(self._config.token_path).read_text()
117+
if token != self._token:
118+
self._token = token
119+
logger.info("Kube token was refreshed")
120+
except asyncio.CancelledError:
121+
raise
122+
except Exception as exc:
123+
logger.exception("Failed to update kube token: %s", exc)
124+
await asyncio.sleep(self._config.token_update_interval_s)
123125

124126
async def aclose(self) -> None:
125-
assert self._client
126-
await self._client.close()
127-
128-
async def request(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
129-
await self.init_if_needed()
130-
assert self._client, "client is not intialized"
131-
doing_retry = kwargs.pop("doing_retry", False)
132-
133-
async with self._client.request(*args, **kwargs) as resp:
127+
if self._client:
128+
await self._client.close()
129+
self._client = None
130+
if self._token_updater_task:
131+
self._token_updater_task.cancel()
132+
with suppress(asyncio.CancelledError):
133+
await self._token_updater_task
134+
self._token_updater_task = None
135+
136+
def _create_headers(
137+
self, headers: Optional[dict[str, Any]] = None
138+
) -> dict[str, Any]:
139+
headers = dict(headers) if headers else {}
140+
if self._config.auth_type == KubeClientAuthType.TOKEN and self._token:
141+
headers["Authorization"] = "Bearer " + self._token
142+
return headers
143+
144+
async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
145+
headers = self._create_headers(kwargs.pop("headers", None))
146+
assert self._client, "client is not initialized"
147+
async with self._client.request(*args, headers=headers, **kwargs) as resp:
134148
resp_payload = await resp.json()
135-
try:
136149
self._raise_for_status(resp_payload)
137150
return resp_payload
138-
except KubeClientUnautorized:
139-
if doing_retry:
140-
raise
141-
# K8s SA's token might be stale, need to refresh it and retry
142-
await self._reload_http_client()
143-
kwargs["doing_retry"] = True
144-
return await self.request(*args, **kwargs)
145151

146152
def _raise_for_status(self, payload: dict[str, Any]) -> None:
147153
kind = payload["kind"]
@@ -150,18 +156,18 @@ def _raise_for_status(self, payload: dict[str, Any]) -> None:
150156
return
151157
code = payload.get("code")
152158
if code == 401:
153-
raise KubeClientUnautorized(payload)
159+
raise KubeClientUnauthorized(payload)
154160
raise KubeClientException(payload)
155161

156162
async def get_nodes(self) -> Sequence[Node]:
157-
payload = await self.request(
163+
payload = await self._request(
158164
method="get", url=self._config.url / "api/v1/nodes"
159165
)
160166
assert payload["kind"] == "NodeList"
161167
return [Node.from_payload(p) for p in payload["items"]]
162168

163169
async def get_node(self, name: str) -> Node:
164-
payload = await self.request(
170+
payload = await self._request(
165171
method="get", url=self._config.url / "api/v1/nodes" / name
166172
)
167173
assert payload["kind"] == "Node"

tests/integration/test_kube_client.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,82 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import os
5+
import tempfile
6+
from collections.abc import AsyncIterator, Iterator
7+
from pathlib import Path
8+
from typing import Any
9+
10+
import aiohttp
11+
import aiohttp.web
12+
import pytest
13+
from yarl import URL
14+
15+
from platform_container_runtime.config import KubeClientAuthType, KubeConfig
116
from platform_container_runtime.kube_client import KubeClient
217

18+
from .conftest import create_local_app_server
19+
20+
21+
class TestKubeClientTokenUpdater:
22+
@pytest.fixture
23+
async def kube_app(self) -> aiohttp.web.Application:
24+
async def _get_nodes(request: aiohttp.web.Request) -> aiohttp.web.Response:
25+
auth = request.headers["Authorization"]
26+
token = auth.split()[-1]
27+
app["token"]["value"] = token
28+
return aiohttp.web.json_response({"kind": "NodeList", "items": []})
29+
30+
app = aiohttp.web.Application()
31+
app["token"] = {"value": ""}
32+
app.router.add_routes([aiohttp.web.get("/api/v1/nodes", _get_nodes)])
33+
return app
34+
35+
@pytest.fixture
36+
async def kube_server(
37+
self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any
38+
) -> AsyncIterator[str]:
39+
async with create_local_app_server(
40+
kube_app, port=unused_tcp_port_factory()
41+
) as address:
42+
yield f"http://{address.host}:{address.port}"
43+
44+
@pytest.fixture
45+
def kube_token_path(self) -> Iterator[str]:
46+
_, path = tempfile.mkstemp()
47+
Path(path).write_text("token-1")
48+
yield path
49+
os.remove(path)
50+
51+
@pytest.fixture
52+
async def kube_client(
53+
self, kube_server: str, kube_token_path: str
54+
) -> AsyncIterator[KubeClient]:
55+
async with KubeClient(
56+
config=KubeConfig(
57+
url=URL(kube_server),
58+
auth_type=KubeClientAuthType.TOKEN,
59+
token_path=kube_token_path,
60+
token_update_interval_s=1,
61+
)
62+
) as client:
63+
yield client
64+
65+
async def test_token_periodically_updated(
66+
self,
67+
kube_app: aiohttp.web.Application,
68+
kube_client: KubeClient,
69+
kube_token_path: str,
70+
) -> None:
71+
await kube_client.get_nodes()
72+
assert kube_app["token"]["value"] == "token-1"
73+
74+
Path(kube_token_path).write_text("token-2")
75+
await asyncio.sleep(2)
76+
77+
await kube_client.get_nodes()
78+
assert kube_app["token"]["value"] == "token-2"
79+
380

481
class TestKubeClient:
582
async def test_get_node(self, kube_client: KubeClient) -> None:

0 commit comments

Comments
 (0)