Commit cdcba2f

fg91 and Fabio Grätz authored
Feat: Enable flytekit to authenticate with proxy in front of FlyteAdmin (#1787)
* Introduce authenticator engine and make proxy auth work
* Use proxy authed session for client credentials flow
* Don't use authenticator engine but do proxy authentication via existing external command authenticator
* Add docstring to AuthenticationHTTPAdapter
* Address todo in docstring
* Create blank session if none provided
* Create blank session if none provided in get_token
* Refresh proxy creds in session when not existing without triggering 401
* Add test for get_session
* Move auth helper test into existing module
* Move auth helper test into existing module
* Add test for upgrade_channel_to_proxy_authenticated
* Auth helper tests without use of responses package
* Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795)
  * Add external command plugin to generate id tokens for identity aware proxy
  * Retrieve desktop app client secret from gcp secret manager
  * Remove comments
  * Introduce a command group that allows adding a command to generate service account id tokens later
  * Document how to use plugin and deploy Flyte with IAP
  * Minor corrections README.md
* Use proxy auth'ed session for device code auth flow
* Fix token client tests
* Make poll token endpoint test more specific
* Make test_client_creds_authenticator test work and more specific
* Make test_client_creds_authenticator_with_custom_scopes test work and more specific
* Implement subcommand to generate id tokens for service accounts
* Test id token generation from service accounts
* Fix plugin requirements
* Document usage of generate-service-account-id-token subcommand
* Document alternative ways to obtain service account id tokens

---------

Signed-off-by: Fabio Grätz <[email protected]>
Signed-off-by: Fabio Graetz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
1 parent cf165f7 commit cdcba2f

18 files changed: +1155 -65 lines changed

flytekit/clients/auth/auth_client.py

Lines changed: 55 additions & 24 deletions
@@ -184,6 +184,11 @@ def __init__(
         redirect_uri: typing.Optional[str] = None,
         endpoint_metadata: typing.Optional[EndpointMetadata] = None,
         verify: typing.Optional[typing.Union[bool, str]] = None,
+        session: typing.Optional[_requests.Session] = None,
+        request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
+        request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
+        refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
+        add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
     ):
         """
         Create new AuthorizationClient
@@ -192,7 +197,9 @@ def __init__(
         :param auth_endpoint: str endpoint where auth metadata can be found
         :param token_endpoint: str endpoint to retrieve token from
         :param scopes: list[str] oauth2 scopes
-        :param client_id
+        :param client_id: oauth2 client id
+        :param redirect_uri: oauth2 redirect uri
+        :param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or failure
         :param verify: (optional) Either a boolean, in which case it controls whether we verify
             the server's TLS certificate, or a string, in which case it must be a path
             to a CA bundle to use. Defaults to ``True``. When set to
@@ -201,6 +208,15 @@ def __init__(
             certificates, which will make your application vulnerable to
             man-in-the-middle (MitM) attacks. Setting verify to ``False``
             may be useful during local development or testing.
+        :param session: (optional) A custom requests.Session object to use for making HTTP requests.
+            If not provided, a new Session object will be created.
+        :param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser
+        :param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token
+        :param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token
+        :param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to
+            the parameters sent when exchanging the auth code for the access token. Defaults to False.
+            Required e.g. for the PKCE flow with flyteadmin.
+            Not required for e.g. the standard OAuth2 flow on GCP.
         """
         self._endpoint = endpoint
         self._auth_endpoint = auth_endpoint
@@ -213,15 +229,13 @@ def __init__(
         self._client_id = client_id
         self._scopes = scopes or []
         self._redirect_uri = redirect_uri
-        self._code_verifier = _generate_code_verifier()
-        code_challenge = _create_code_challenge(self._code_verifier)
-        self._code_challenge = code_challenge
         state = _generate_state_parameter()
         self._state = state
         self._verify = verify
         self._headers = {"content-type": "application/x-www-form-urlencoded"}
+        self._session = session or _requests.Session()

-        self._params = {
+        self._request_auth_code_params = {
             "client_id": client_id, # This must match the Client ID of the OAuth application.
             "response_type": "code", # Indicates the authorization code grant
             "scope": " ".join(s.strip("' ") for s in self._scopes).strip(
@@ -230,10 +244,18 @@ def __init__(
             # callback location where the user-agent will be directed to.
             "redirect_uri": self._redirect_uri,
             "state": state,
-            "code_challenge": code_challenge,
-            "code_challenge_method": "S256",
         }

+        if request_auth_code_params:
+            # Allow adding additional parameters to the request_auth_code_params
+            self._request_auth_code_params.update(request_auth_code_params)
+
+        self._request_access_token_params = request_access_token_params or {}
+        self._refresh_access_token_params = refresh_access_token_params or {}
+
+        if add_request_auth_code_params_to_request_access_token_params:
+            self._request_access_token_params.update(self._request_auth_code_params)
+
     def __repr__(self):
         return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})"

@@ -249,7 +271,7 @@ def _create_callback_server(self):

     def _request_authorization_code(self):
         scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
-        query = _urlencode(self._params)
+        query = _urlencode(self._request_auth_code_params)
         endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
         logging.debug(f"Requesting authorization code through {endpoint}")
         _webbrowser.open_new_tab(endpoint)
@@ -262,33 +284,38 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
             "refresh_token": "bar",
             "token_type": "Bearer"
         }
+
+        Can additionally contain "expires_in" and "id_token" fields.
         """
         response_body = auth_token_resp.json()
         refresh_token = None
+        id_token = None
         if "access_token" not in response_body:
             raise ValueError('Expected "access_token" in response from oauth server')
         if "refresh_token" in response_body:
             refresh_token = response_body["refresh_token"]
         if "expires_in" in response_body:
             expires_in = response_body["expires_in"]
         access_token = response_body["access_token"]
+        if "id_token" in response_body:
+            id_token = response_body["id_token"]

-        return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)
+        return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token)

     def _request_access_token(self, auth_code) -> Credentials:
         if self._state != auth_code.state:
             raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
-        self._params.update(
-            {
-                "code": auth_code.code,
-                "code_verifier": self._code_verifier,
-                "grant_type": "authorization_code",
-            }
-        )

-        resp = _requests.post(
+        params = {
+            "code": auth_code.code,
+            "grant_type": "authorization_code",
+        }
+
+        params.update(self._request_access_token_params)
+
+        resp = self._session.post(
             url=self._token_endpoint,
-            data=self._params,
+            data=params,
             headers=self._headers,
             allow_redirects=False,
             verify=self._verify,
@@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials:
         if credentials.refresh_token is None:
             raise ValueError("no refresh token available with which to refresh authorization credentials")

-        resp = _requests.post(
+        data = {
+            "refresh_token": credentials.refresh_token,
+            "grant_type": "refresh_token",
+            "client_id": self._client_id,
+        }
+
+        data.update(self._refresh_access_token_params)
+
+        resp = self._session.post(
             url=self._token_endpoint,
-            data={
-                "grant_type": "refresh_token",
-                "client_id": self._client_id,
-                "refresh_token": credentials.refresh_token,
-            },
+            data=data,
             headers=self._headers,
             allow_redirects=False,
             verify=self._verify,
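
The parameter handling introduced in `auth_client.py` can be read as three dictionary merges: caller-supplied values extend the auth-code request, optionally flow into the token-exchange request, and finally override the fixed `code`/`grant_type` fields. The following is a minimal, standalone sketch of that merge order, not the flytekit implementation. The function names `build_request_params` and `access_token_request_body` are hypothetical, and unlike the diff the sketch copies the caller's dict instead of updating it in place.

```python
import typing


def build_request_params(
    client_id: str,
    redirect_uri: str,
    state: str,
    scopes: typing.List[str],
    request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
    request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
    add_auth_code_params_to_access_token_params: bool = False,
) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]:
    """Hypothetical helper mirroring the merge order shown in the diff."""
    # Base parameters for the login URI opened in the browser.
    auth_code_params = {
        "client_id": client_id,
        "response_type": "code",
        "scope": " ".join(scopes),
        "redirect_uri": redirect_uri,
        "state": state,
    }
    # Caller-supplied values (e.g. a PKCE code challenge) extend or override them.
    auth_code_params.update(request_auth_code_params or {})

    # Copy so the caller's dict is not mutated (an assumption of this sketch).
    access_token_params = dict(request_access_token_params or {})
    if add_auth_code_params_to_access_token_params:
        access_token_params.update(auth_code_params)
    return auth_code_params, access_token_params


def access_token_request_body(
    code: str, access_token_params: typing.Dict[str, str]
) -> typing.Dict[str, str]:
    """Build the form body for exchanging the auth code for a token."""
    body = {"code": code, "grant_type": "authorization_code"}
    body.update(access_token_params)
    return body


auth_params, token_params = build_request_params(
    client_id="my-client",
    redirect_uri="http://localhost:53593/callback",
    state="random-state",
    scopes=["openid", "offline"],
    request_auth_code_params={"code_challenge": "abc", "code_challenge_method": "S256"},
    request_access_token_params={"code_verifier": "verifier"},
    add_auth_code_params_to_access_token_params=True,
)
body = access_token_request_body("auth-code", token_params)
```

With the flag set, `body` carries both the PKCE verifier and the original auth-code parameters, which is what the docstring describes as required for the PKCE flow with flyteadmin. With the flag unset, only the explicitly passed token parameters are sent.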

flytekit/clients/auth/authenticator.py

Lines changed: 31 additions & 1 deletion
@@ -5,6 +5,7 @@
 from dataclasses import dataclass

 import click
+import requests

 from . import token_client
 from .auth_client import AuthorizationClient
@@ -95,16 +96,24 @@ def __init__(
         cfg_store: ClientConfigStore,
         header_key: typing.Optional[str] = None,
         verify: typing.Optional[typing.Union[bool, str]] = None,
+        session: typing.Optional[requests.Session] = None,
     ):
         """
         Initialize with default creds from KeyStore using the endpoint name
         """
         super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
         self._cfg_store = cfg_store
         self._auth_client = None
+        self._session = session or requests.Session()

     def _initialize_auth_client(self):
         if not self._auth_client:
+
+            from .auth_client import _create_code_challenge, _generate_code_verifier
+
+            code_verifier = _generate_code_verifier()
+            code_challenge = _create_code_challenge(code_verifier)
+
             cfg = self._cfg_store.get_client_config()
             self._set_header_key(cfg.header_key)
             self._auth_client = AuthorizationClient(
@@ -115,6 +124,16 @@ def _initialize_auth_client(self):
                 auth_endpoint=cfg.authorization_endpoint,
                 token_endpoint=cfg.token_endpoint,
                 verify=self._verify,
+                session=self._session,
+                request_auth_code_params={
+                    "code_challenge": code_challenge,
+                    "code_challenge_method": "S256",
+                },
+                request_access_token_params={
+                    "code_verifier": code_verifier,
+                },
+                refresh_access_token_params={},
+                add_request_auth_code_params_to_request_access_token_params=True,
             )

     def refresh_credentials(self):
@@ -176,6 +195,7 @@ def __init__(
         http_proxy_url: typing.Optional[str] = None,
         verify: typing.Optional[typing.Union[bool, str]] = None,
         audience: typing.Optional[str] = None,
+        session: typing.Optional[requests.Session] = None,
     ):
         if not client_id or not client_secret:
             raise ValueError("Client ID and Client SECRET both are required.")
@@ -186,6 +206,7 @@ def __init__(
         self._client_id = client_id
         self._client_secret = client_secret
         self._audience = audience or cfg.audience
+        self._session = session or requests.Session()
         super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

     def refresh_credentials(self):
@@ -211,6 +232,7 @@ def refresh_credentials(self):
             verify=self._verify,
             scopes=scopes,
             audience=audience,
+            session=self._session,
         )

         logging.info("Retrieved new token, expires in {}".format(expires_in))
@@ -234,6 +256,7 @@ def __init__(
         audience: typing.Optional[str] = None,
         http_proxy_url: typing.Optional[str] = None,
         verify: typing.Optional[typing.Union[bool, str]] = None,
+        session: typing.Optional[requests.Session] = None,
     ):
         self._audience = audience
         cfg = cfg_store.get_client_config()
@@ -245,6 +268,7 @@ def __init__(
             raise AuthenticationError(
                 "Device Authentication is not available on the Flyte backend / authentication server"
             )
+        self._session = session or requests.Session()
         super().__init__(
             endpoint=endpoint,
             header_key=header_key or cfg.header_key,
@@ -255,7 +279,13 @@ def __init__(

     def refresh_credentials(self):
         resp = token_client.get_device_code(
-            self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
+            self._device_auth_endpoint,
+            self._client_id,
+            self._audience,
+            self._scope,
+            self._http_proxy_url,
+            self._verify,
+            self._session,
         )
         text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: {click.style(resp.user_code, fg='blue')}"
         click.secho(text)
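
With this change the PKCE values are generated in the authenticator and handed to `AuthorizationClient` as plain request parameters. As a rough illustration of what a verifier/challenge pair looks like under the `S256` method of RFC 7636, here is a standalone sketch. The helper names are hypothetical, and it is not a copy of flytekit's private `_generate_code_verifier` / `_create_code_challenge`, whose exact implementation is not shown in this diff.

```python
import base64
import hashlib
import secrets


def generate_code_verifier() -> str:
    """Return a random URL-safe verifier (43 characters from 32 random bytes)."""
    raw = secrets.token_bytes(32)
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def create_code_challenge(code_verifier: str) -> str:
    """Derive the S256 challenge: base64url(sha256(verifier)) without padding."""
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


code_verifier = generate_code_verifier()
code_challenge = create_code_challenge(code_verifier)

# Parameters shaped like the ones the authenticator passes in the diff above.
request_auth_code_params = {
    "code_challenge": code_challenge,
    "code_challenge_method": "S256",
}
request_access_token_params = {"code_verifier": code_verifier}
```

The challenge goes out with the browser login request, while the verifier is only revealed during the token exchange, so the authorization server can check that both requests came from the same client.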

flytekit/clients/auth/keyring.py

Lines changed: 22 additions & 8 deletions
@@ -3,7 +3,7 @@
 from dataclasses import dataclass

 import keyring as _keyring
-from keyring.errors import NoKeyringError
+from keyring.errors import NoKeyringError, PasswordDeleteError


 @dataclass
@@ -16,6 +16,7 @@ class Credentials(object):
     refresh_token: str = "na"
     for_endpoint: str = "flyte-default"
     expires_in: typing.Optional[int] = None
+    id_token: typing.Optional[str] = None


 class KeyringStore:
@@ -25,20 +26,28 @@ class KeyringStore:

     _access_token_key = "access_token"
     _refresh_token_key = "refresh_token"
+    _id_token_key = "id_token"

     @staticmethod
     def store(credentials: Credentials) -> Credentials:
         try:
-            _keyring.set_password(
-                credentials.for_endpoint,
-                KeyringStore._refresh_token_key,
-                credentials.refresh_token,
-            )
+            if credentials.refresh_token:
+                _keyring.set_password(
+                    credentials.for_endpoint,
+                    KeyringStore._refresh_token_key,
+                    credentials.refresh_token,
+                )
             _keyring.set_password(
                 credentials.for_endpoint,
                 KeyringStore._access_token_key,
                 credentials.access_token,
             )
+            if credentials.id_token:
+                _keyring.set_password(
+                    credentials.for_endpoint,
+                    KeyringStore._id_token_key,
+                    credentials.id_token,
+                )
         except NoKeyringError as e:
             logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
         return credentials
@@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
         try:
             refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
             access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
+            id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key)
         except NoKeyringError as e:
             logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
             return None

-        if not access_token:
+        if not access_token and not id_token:
             return None
-        return Credentials(access_token, refresh_token, for_endpoint)
+        return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

     @staticmethod
     def delete(for_endpoint: str):
         try:
             _keyring.delete_password(for_endpoint, KeyringStore._access_token_key)
             _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key)
+            try:
+                _keyring.delete_password(for_endpoint, KeyringStore._id_token_key)
+            except PasswordDeleteError as e:
+                logging.debug(f"Id token not found in key store, not deleting. Error: {e}")
         except NoKeyringError as e:
             logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
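
The store and retrieve rules after this change (refresh and ID tokens are optional, and a cached entry counts as present when either an access token or an ID token exists) can be exercised without a system keyring. The sketch below swaps the `keyring` backend for a plain dictionary. `InMemoryStore` is a hypothetical stand-in, and the `access_token` field of `Credentials` is assumed, since only the later fields appear in the diff.

```python
import typing
from dataclasses import dataclass


@dataclass
class Credentials:
    access_token: str  # assumed first field; not shown in the diff
    refresh_token: typing.Optional[str] = "na"
    for_endpoint: str = "flyte-default"
    expires_in: typing.Optional[int] = None
    id_token: typing.Optional[str] = None


class InMemoryStore:
    """Hypothetical dict-backed stand-in for KeyringStore."""

    def __init__(self) -> None:
        self._data: typing.Dict[typing.Tuple[str, str], str] = {}

    def store(self, credentials: Credentials) -> Credentials:
        # Optional tokens are only written when present.
        if credentials.refresh_token:
            self._data[(credentials.for_endpoint, "refresh_token")] = credentials.refresh_token
        self._data[(credentials.for_endpoint, "access_token")] = credentials.access_token
        if credentials.id_token:
            self._data[(credentials.for_endpoint, "id_token")] = credentials.id_token
        return credentials

    def retrieve(self, for_endpoint: str) -> typing.Optional[Credentials]:
        refresh_token = self._data.get((for_endpoint, "refresh_token"))
        access_token = self._data.get((for_endpoint, "access_token"))
        id_token = self._data.get((for_endpoint, "id_token"))
        # Nothing usable is cached only if both tokens are missing.
        if not access_token and not id_token:
            return None
        return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

    def delete(self, for_endpoint: str) -> None:
        # A missing entry is not an error, mirroring the tolerated missing ID token.
        for key in ("access_token", "refresh_token", "id_token"):
            self._data.pop((for_endpoint, key), None)


store = InMemoryStore()
store.store(Credentials("access", None, "flyte.example.com", id_token="id-token"))
cached = store.retrieve("flyte.example.com")
```

Here `cached` holds the access and ID tokens with `refresh_token` left as `None`, because no refresh token was ever written for that endpoint.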

flytekit/clients/auth/token_client.py

Lines changed: 9 additions & 2 deletions
@@ -78,6 +78,7 @@ def get_token(
     grant_type: GrantType = GrantType.CLIENT_CREDS,
     http_proxy_url: typing.Optional[str] = None,
     verify: typing.Optional[typing.Union[bool, str]] = None,
+    session: typing.Optional[requests.Session] = None,
 ) -> typing.Tuple[str, int]:
     """
     :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
@@ -103,7 +104,10 @@ def get_token(
         body["audience"] = audience

     proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
-    response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)
+
+    if not session:
+        session = requests.Session()
+    response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

     if not response.ok:
         j = response.json()
@@ -125,6 +129,7 @@ def get_device_code(
     scope: typing.Optional[typing.List[str]] = None,
     http_proxy_url: typing.Optional[str] = None,
     verify: typing.Optional[typing.Union[bool, str]] = None,
+    session: typing.Optional[requests.Session] = None,
 ) -> DeviceCodeResponse:
     """
     Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
@@ -133,7 +138,9 @@ def get_device_code(
     _scope = " ".join(scope) if scope is not None else ""
     payload = {"client_id": client_id, "scope": _scope, "audience": audience}
     proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
-    resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
+    if not session:
+        session = requests.Session()
+    resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
     if not resp.ok:
         raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
     return DeviceCodeResponse.from_json_response(resp.json())
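
Both `get_token` and `get_device_code` now accept an optional session and create a blank one when none is given, so a caller can hand in a session that already carries proxy credentials. The sketch below illustrates that injection pattern without any network access. `RecordingSession` is a hypothetical stand-in for `requests.Session`, the simplified `get_token` signature is not flytekit's, and the `proxy-authorization` header name is an assumption for illustration only.

```python
import typing


class RecordingSession:
    """Hypothetical stand-in for requests.Session that records POST calls."""

    def __init__(self, headers: typing.Optional[typing.Dict[str, str]] = None) -> None:
        self.headers = dict(headers or {})
        self.calls: typing.List[typing.Tuple[str, dict, dict]] = []

    def post(self, url: str, data: typing.Optional[dict] = None, **kwargs) -> dict:
        # Record the URL, the form body, and the headers the session would attach.
        self.calls.append((url, dict(data or {}), dict(self.headers)))
        return {"access_token": "dummy-token", "expires_in": 3600}


def get_token(
    token_endpoint: str,
    body: typing.Dict[str, str],
    session: typing.Optional[RecordingSession] = None,
) -> typing.Tuple[str, int]:
    """Simplified illustration of the 'create blank session if none provided' pattern."""
    if not session:
        session = RecordingSession()
    response = session.post(token_endpoint, data=body)
    return response["access_token"], response["expires_in"]


# A session prepared elsewhere, e.g. with credentials for a proxy in front of the backend.
proxy_session = RecordingSession(headers={"proxy-authorization": "Bearer example-id-token"})
token, expires_in = get_token(
    "https://flyte.example.com/oauth2/token",
    {"grant_type": "client_credentials"},
    session=proxy_session,
)
```

Because the function only depends on the session's `post` method, the same call path works with a default session, a proxy-authenticated one, or a test double.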
