1
+ import asyncio
1
2
import logging
2
3
import ssl
3
4
from collections .abc import Sequence
5
+ from contextlib import suppress
4
6
from dataclasses import dataclass , field
5
7
from pathlib import Path
6
8
from types import TracebackType
13
15
logger = logging .getLogger (__name__ )
14
16
15
17
16
- class KubeClientUnautorized (Exception ):
18
+ class KubeClientUnauthorized (Exception ):
17
19
pass
18
20
19
21
@@ -60,6 +62,7 @@ def __init__(
60
62
self ._token = config .token
61
63
self ._trace_configs = trace_configs
62
64
self ._client : Optional [aiohttp .ClientSession ] = None
65
+ self ._token_updater_task : Optional [asyncio .Task [None ]] = None
63
66
64
67
def _create_ssl_context (self ) -> Optional [ssl .SSLContext ]:
65
68
if self ._config .url .scheme != "https" :
@@ -76,7 +79,7 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
76
79
return ssl_context
77
80
78
81
async def __aenter__ (self ) -> "KubeClient" :
79
- self . _client = await self ._create_http_client ()
82
+ await self ._init ()
80
83
return self
81
84
82
85
async def __aexit__ (
@@ -87,61 +90,64 @@ async def __aexit__(
87
90
) -> None :
88
91
await self .aclose ()
89
92
90
- async def _create_http_client (self ) -> aiohttp . ClientSession :
93
+ async def _init (self ) -> None :
91
94
connector = aiohttp .TCPConnector (
92
95
limit = self ._config .conn_pool_size ,
93
96
force_close = self ._config .conn_force_close ,
94
97
ssl = self ._create_ssl_context (),
95
98
)
96
- if self ._config .auth_type == KubeClientAuthType .TOKEN :
97
- token = self ._token
98
- if not token :
99
- assert self ._config .token_path is not None
100
- token = Path (self ._config .token_path ).read_text ()
101
- headers = {"Authorization" : "Bearer " + token }
102
- else :
103
- headers = {}
99
+ if self ._config .token_path :
100
+ self ._token = Path (self ._config .token_path ).read_text ()
101
+ self ._token_updater_task = asyncio .create_task (self ._start_token_updater ())
104
102
timeout = aiohttp .ClientTimeout (
105
103
connect = self ._config .conn_timeout_s , total = self ._config .read_timeout_s
106
104
)
107
- return aiohttp .ClientSession (
105
+ self . _client = aiohttp .ClientSession (
108
106
connector = connector ,
109
107
timeout = timeout ,
110
- headers = headers ,
111
108
trace_configs = self ._trace_configs ,
112
109
)
113
110
114
- async def _reload_http_client (self ) -> None :
115
- if self ._client :
116
- await self ._client .close ()
117
- self ._token = None
118
- self ._client = await self ._create_http_client ()
119
-
120
- async def init_if_needed (self ) -> None :
121
- if not self ._client or self ._client .closed :
122
- await self ._reload_http_client ()
111
+ async def _start_token_updater (self ) -> None :
112
+ if not self ._config .token_path :
113
+ return
114
+ while True :
115
+ try :
116
+ token = Path (self ._config .token_path ).read_text ()
117
+ if token != self ._token :
118
+ self ._token = token
119
+ logger .info ("Kube token was refreshed" )
120
+ except asyncio .CancelledError :
121
+ raise
122
+ except Exception as exc :
123
+ logger .exception ("Failed to update kube token: %s" , exc )
124
+ await asyncio .sleep (self ._config .token_update_interval_s )
123
125
124
126
async def aclose (self ) -> None :
125
- assert self ._client
126
- await self ._client .close ()
127
-
128
- async def request (self , * args : Any , ** kwargs : Any ) -> dict [str , Any ]:
129
- await self .init_if_needed ()
130
- assert self ._client , "client is not intialized"
131
- doing_retry = kwargs .pop ("doing_retry" , False )
132
-
133
- async with self ._client .request (* args , ** kwargs ) as resp :
127
+ if self ._client :
128
+ await self ._client .close ()
129
+ self ._client = None
130
+ if self ._token_updater_task :
131
+ self ._token_updater_task .cancel ()
132
+ with suppress (asyncio .CancelledError ):
133
+ await self ._token_updater_task
134
+ self ._token_updater_task = None
135
+
136
+ def _create_headers (
137
+ self , headers : Optional [dict [str , Any ]] = None
138
+ ) -> dict [str , Any ]:
139
+ headers = dict (headers ) if headers else {}
140
+ if self ._config .auth_type == KubeClientAuthType .TOKEN and self ._token :
141
+ headers ["Authorization" ] = "Bearer " + self ._token
142
+ return headers
143
+
144
+ async def _request (self , * args : Any , ** kwargs : Any ) -> dict [str , Any ]:
145
+ headers = self ._create_headers (kwargs .pop ("headers" , None ))
146
+ assert self ._client , "client is not initialized"
147
+ async with self ._client .request (* args , headers = headers , ** kwargs ) as resp :
134
148
resp_payload = await resp .json ()
135
- try :
136
149
self ._raise_for_status (resp_payload )
137
150
return resp_payload
138
- except KubeClientUnautorized :
139
- if doing_retry :
140
- raise
141
- # K8s SA's token might be stale, need to refresh it and retry
142
- await self ._reload_http_client ()
143
- kwargs ["doing_retry" ] = True
144
- return await self .request (* args , ** kwargs )
145
151
146
152
def _raise_for_status (self , payload : dict [str , Any ]) -> None :
147
153
kind = payload ["kind" ]
@@ -150,18 +156,18 @@ def _raise_for_status(self, payload: dict[str, Any]) -> None:
150
156
return
151
157
code = payload .get ("code" )
152
158
if code == 401 :
153
- raise KubeClientUnautorized (payload )
159
+ raise KubeClientUnauthorized (payload )
154
160
raise KubeClientException (payload )
155
161
156
162
async def get_nodes (self ) -> Sequence [Node ]:
157
- payload = await self .request (
163
+ payload = await self ._request (
158
164
method = "get" , url = self ._config .url / "api/v1/nodes"
159
165
)
160
166
assert payload ["kind" ] == "NodeList"
161
167
return [Node .from_payload (p ) for p in payload ["items" ]]
162
168
163
169
async def get_node (self , name : str ) -> Node :
164
- payload = await self .request (
170
+ payload = await self ._request (
165
171
method = "get" , url = self ._config .url / "api/v1/nodes" / name
166
172
)
167
173
assert payload ["kind" ] == "Node"
0 commit comments