Skip to content

Commit eae254e

Browse files
feat(core): use oauth 2.0 device auth grant (#2722)
1 parent 713b4a4 commit eae254e

File tree

7 files changed

+178
-106
lines changed

7 files changed

+178
-106
lines changed

renku/command/login.py

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
"""Logging in to a Renku deployment."""
1919

2020
import os
21-
import sys
21+
import time
2222
import urllib
23-
import uuid
2423
import webbrowser
2524
from typing import TYPE_CHECKING
2625

@@ -37,6 +36,8 @@
3736

3837

3938
CONFIG_SECTION = "http"
39+
KEYCLOAK_REALM = "Renku"
40+
CLIENT_ID = "renku-cli"
4041

4142

4243
def login_command():
@@ -67,48 +68,82 @@ def _login(endpoint, git_login, yes, client_dispatcher: IClientDispatcher):
6768
else:
6869
raise errors.ParameterError("Cannot find a unique remote URL for project.")
6970

70-
cli_nonce = str(uuid.uuid4())
71+
auth_server_url = _get_url(
72+
parsed_endpoint, path=f"auth/realms/{KEYCLOAK_REALM}/protocol/openid-connect/auth/device"
73+
)
7174

72-
communication.echo(f"Please log in at {parsed_endpoint.geturl()} on your browser.")
75+
try:
76+
response = requests.post(auth_server_url, data={"client_id": CLIENT_ID})
77+
except errors.RequestError as e:
78+
raise errors.RequestError(f"Cannot connect to authorization server at {auth_server_url}.") from e
7379

74-
login_url = _get_url(parsed_endpoint, "/api/auth/login", cli_nonce=cli_nonce)
75-
webbrowser.open_new_tab(login_url)
80+
requests.check_response(response=response)
81+
data = response.json()
7682

77-
server_nonce = communication.prompt("Once completed, enter the security code that you receive at the end")
78-
cli_token_url = _get_url(parsed_endpoint, "/api/auth/cli-token", cli_nonce=cli_nonce, server_nonce=server_nonce)
83+
verification_uri = data.get("verification_uri")
84+
user_code = data.get("user_code")
85+
verification_uri_complete = f"{verification_uri}?user_code={user_code}"
7986

80-
try:
81-
response = requests.get(cli_token_url)
82-
except errors.RequestError as e:
83-
raise errors.OperationError("Cannot get access token from remote host.") from e
87+
communication.echo(
88+
f"Please grant access to '{CLIENT_ID}' in your browser.\n"
89+
f"If a browser window does not open automatically, go to {verification_uri_complete}"
90+
)
8491

85-
if response.status_code == 200:
86-
access_token = response.json().get("access_token")
87-
_store_token(parsed_endpoint.netloc, access_token)
92+
webbrowser.open_new_tab(verification_uri_complete)
8893

89-
if git_login:
90-
_set_git_credential_helper(repository=client.repository, hostname=parsed_endpoint.netloc)
91-
backup_remote_name, backup_exists, remote = create_backup_remote(
92-
repository=client.repository, remote_name=remote_name, url=remote_url # type:ignore
93-
)
94-
if backup_exists:
95-
communication.echo(f"Backup remote '{backup_remote_name}' already exists. Ignoring '--git' flag.")
96-
elif not remote:
97-
communication.error(f"Cannot create backup remote '{backup_remote_name}' for '{remote_url}'")
94+
polling_interval = min(data.get("interval", 5), 5)
95+
token_url = _get_url(parsed_endpoint, path=f"auth/realms/{KEYCLOAK_REALM}/protocol/openid-connect/token")
96+
device_code = data.get("device_code")
97+
98+
while True:
99+
time.sleep(polling_interval)
100+
101+
response = requests.post(
102+
token_url,
103+
data={
104+
"device_code": device_code,
105+
"client_id": CLIENT_ID,
106+
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
107+
},
108+
)
109+
status_code = response.status_code
110+
if status_code == 200:
111+
break
112+
elif status_code == 400:
113+
error = response.json().get("error")
114+
115+
if error == "authorization_pending":
116+
continue
117+
elif error == "slow_down":
118+
polling_interval += 1
119+
elif error == "access_denied":
120+
raise errors.AuthenticationError("Access denied")
121+
elif error == "expired_token":
122+
raise errors.AuthenticationError("Session expired, try again")
98123
else:
99-
_set_renku_url_for_remote(
100-
repository=client.repository,
101-
remote_name=remote_name, # type:ignore
102-
remote_url=remote_url, # type:ignore
103-
hostname=parsed_endpoint.netloc,
104-
)
124+
raise errors.AuthenticationError(f"Invalid error message from server: {response.json()}")
125+
else:
126+
raise errors.AuthenticationError(f"Invalid status code from server: {status_code} - {response.content}")
105127

106-
else:
107-
communication.error(
108-
f"Remote host did not return an access token: {parsed_endpoint.geturl()}, "
109-
f"status code: {response.status_code}"
128+
access_token = response.json().get("access_token")
129+
_store_token(parsed_endpoint.netloc, access_token)
130+
131+
if git_login:
132+
_set_git_credential_helper(repository=client.repository, hostname=parsed_endpoint.netloc)
133+
backup_remote_name, backup_exists, remote = create_backup_remote(
134+
repository=client.repository, remote_name=remote_name, url=remote_url # type:ignore
110135
)
111-
sys.exit(1)
136+
if backup_exists:
137+
communication.echo(f"Backup remote '{backup_remote_name}' already exists. Ignoring '--git' flag.")
138+
elif not remote:
139+
communication.error(f"Cannot create backup remote '{backup_remote_name}' for '{remote_url}'")
140+
else:
141+
_set_renku_url_for_remote(
142+
repository=client.repository,
143+
remote_name=remote_name, # type:ignore
144+
remote_url=remote_url, # type:ignore
145+
hostname=parsed_endpoint.netloc,
146+
)
112147

113148

114149
def _parse_endpoint(endpoint):
@@ -119,7 +154,7 @@ def _parse_endpoint(endpoint):
119154
return parsed_endpoint
120155

121156

122-
def _get_url(parsed_endpoint, path, **query_args):
157+
def _get_url(parsed_endpoint, path, **query_args) -> str:
123158
query = urllib.parse.urlencode(query_args)
124159
return parsed_endpoint._replace(path=path, query=query).geturl()
125160

renku/core/dataset/providers/dataverse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,11 @@ def _post(self, url, json=None, data=None, files=None):
546546

547547
@staticmethod
548548
def _check_response(response):
549-
if response.status_code not in [200, 201, 202]:
550-
if response.status_code == 401:
551-
raise errors.AuthenticationError("Access unauthorized - update access token.")
549+
from renku.core.util import requests
550+
551+
try:
552+
requests.check_response(response=response)
553+
except errors.RequestError:
552554
json_res = response.json()
553555
raise errors.ExportError(
554556
"HTTP {} - Cannot export dataset: {}".format(

renku/core/dataset/providers/olos.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ def upload_file(self, full_path, path_in_dataset):
231231

232232
return response
233233

234-
def _make_url(self, server_url, api_path, **query_params):
234+
@staticmethod
235+
def _make_url(server_url, api_path, **query_params):
235236
"""Create URL for creating a dataset."""
236237
url_parts = urlparse.urlparse(server_url)
237238

@@ -259,15 +260,17 @@ def _post(self, url, json=None, data=None, files=None):
259260

260261
@staticmethod
261262
def _check_response(response):
263+
from renku.core.util import requests
264+
262265
if len(response.history) > 0:
263266
raise errors.ExportError(
264267
f"Couldn't execute request to {response.request.url}, got redirected to {response.url}."
265268
"Maybe you mixed up http and https in the server url?"
266269
)
267270

268-
if response.status_code not in [200, 201, 202]:
269-
if response.status_code == 401:
270-
raise errors.AuthenticationError("Access unauthorized - update access token.")
271+
try:
272+
requests.check_response(response=response)
273+
except errors.RequestError:
271274
json_res = response.json()
272275
raise errors.ExportError(
273276
"HTTP {} - Cannot export dataset: {}".format(

renku/core/dataset/providers/zenodo.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
import attr
2929
from tqdm import tqdm
3030

31+
from renku.core import errors
3132
from renku.core.dataset.providers.api import ExporterApi, ProviderApi, ProviderRecordSerializerApi
3233
from renku.core.util.file_size import bytes_to_unit
3334

3435
if TYPE_CHECKING:
3536
from renku.core.dataset.providers.models import ProviderDataset
3637

38+
3739
ZENODO_BASE_URL = "https://zenodo.org"
3840
ZENODO_SANDBOX_URL = "https://sandbox.zenodo.org/"
3941

@@ -358,7 +360,7 @@ def new_deposition(self):
358360
response = requests.post(
359361
url=self.new_deposit_url, params=self.exporter.default_params, json={}, headers=self.exporter.HEADERS
360362
)
361-
requests.check_response(response)
363+
self._check_response(response)
362364

363365
return response
364366

@@ -371,7 +373,7 @@ def upload_file(self, filepath, path_in_repo):
371373
response = requests.post(
372374
url=self.upload_file_url, params=self.exporter.default_params, data=request_payload, files=file
373375
)
374-
requests.check_response(response)
376+
self._check_response(response)
375377

376378
return response
377379

@@ -402,7 +404,7 @@ def attach_metadata(self, dataset, tag):
402404
data=json.dumps(request_payload),
403405
headers=self.exporter.HEADERS,
404406
)
405-
requests.check_response(response)
407+
self._check_response(response)
406408

407409
return response
408410

@@ -411,7 +413,7 @@ def publish_deposition(self, secret):
411413
from renku.core.util import requests
412414

413415
response = requests.post(url=self.publish_url, params=self.exporter.default_params)
414-
requests.check_response(response)
416+
self._check_response(response)
415417

416418
return response
417419

@@ -420,6 +422,25 @@ def __attrs_post_init__(self):
420422
response = self.new_deposition()
421423
self.id = response.json()["id"]
422424

425+
@staticmethod
426+
def _check_response(response):
427+
from renku.core.util import requests
428+
429+
try:
430+
requests.check_response(response=response)
431+
except errors.RequestError:
432+
if response.status_code == 400:
433+
err_response = response.json()
434+
messages = [
435+
'"{0}" failed with "{1}"'.format(err["field"], err["message"]) for err in err_response["errors"]
436+
]
437+
438+
raise errors.ExportError(
439+
"\n" + "\n".join(messages) + "\nSee `renku dataset edit -h` for details on how to edit" " metadata"
440+
)
441+
else:
442+
raise errors.ExportError(response.content)
443+
423444

424445
@attr.s
425446
class ZenodoExporter(ExporterApi):

renku/core/util/requests.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,14 @@ def get_redirect_url(url) -> str:
107107

108108
def check_response(response):
109109
"""Check for expected response status code."""
110-
if response.status_code not in [200, 201, 202]:
111-
if response.status_code == 401:
112-
raise errors.AuthenticationError("Access unauthorized - update access token.")
113-
114-
if response.status_code == 400:
115-
err_response = response.json()
116-
messages = [
117-
'"{0}" failed with "{1}"'.format(err["field"], err["message"]) for err in err_response["errors"]
118-
]
119-
120-
raise errors.ExportError(
121-
"\n" + "\n".join(messages) + "\nSee `renku dataset edit -h` for details on how to edit" " metadata"
122-
)
123-
124-
raise errors.ExportError(response.content)
110+
if response.status_code in [200, 201, 202]:
111+
return
112+
elif response.status_code == 401:
113+
raise errors.AuthenticationError("Access unauthorized - update access token.")
114+
else:
115+
content = response.content.decode("utf-8") if response.content else ""
116+
message = f"Request failed with code {response.status_code}: {content}"
117+
raise errors.RequestError(message)
125118

126119

127120
def download_file(base_directory: Union[Path, str], url: str, filename, extract, chunk_size=16384):

tests/cli/fixtures/cli_gateway.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,73 @@
1616
# See the License for the specific language governing permissions and
1717
# limitations under the License.
1818
"""Renku CLI fixtures for Gateway."""
19+
1920
import json
21+
import urllib
2022

2123
import pytest
2224
import responses
2325
from _pytest.monkeypatch import MonkeyPatch
2426

2527
ENDPOINT = "renku.deployment.ch"
2628
ACCESS_TOKEN = "jwt-token"
27-
USER_CODE = "valid_user_code"
29+
DEVICE_CODE = "valid-device-code"
2830

2931

3032
@pytest.fixture(scope="module")
3133
def mock_login():
32-
"""Monkey patch webbrowser module for renku login."""
34+
"""Monkey patch webbrowser package and keycloak endpoints for renku login."""
3335
import webbrowser
3436

3537
with MonkeyPatch().context() as monkey_patch:
36-
monkey_patch.setattr(webbrowser, "open_new_tab", lambda _: None)
38+
monkey_patch.setattr(webbrowser, "open_new_tab", lambda _: True)
3739

3840
with responses.RequestsMock(assert_all_requests_are_fired=False) as requests_mock:
3941

40-
def callback(token):
41-
def func(request):
42-
if request.params.get("server_nonce") == USER_CODE:
42+
def device_callback(request):
43+
data = dict(urllib.parse.parse_qsl(request.body))
44+
if data.get("client_id") != "renku-cli":
45+
return 400, {"Content-Type": "application/json"}, json.dumps({"error": "invalid_client"})
46+
47+
data = {
48+
"verification_uri": f"https://{ENDPOINT}/auth/realms/Renku/device",
49+
"user_code": "ABC-DEF",
50+
"interval": 0,
51+
"device_code": DEVICE_CODE,
52+
}
53+
return 200, {"Content-Type": "application/json"}, json.dumps(data)
54+
55+
def create_token_callback(token):
56+
def token_callback(request):
57+
data = dict(urllib.parse.parse_qsl(request.body))
58+
if (
59+
data.get("device_code") == DEVICE_CODE
60+
and data.get("client_id") == "renku-cli"
61+
and data.get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code"
62+
):
4363
return 200, {"Content-Type": "application/json"}, json.dumps({"access_token": token})
4464

45-
return 404, {"Content-Type": "application/json"}, ""
65+
return 400, {"Content-Type": "application/json"}, ""
4666

47-
return func
67+
return token_callback
4868

4969
requests_mock.add_passthru("https://pypi.org/")
5070

5171
class RequestMockWrapper:
5272
@staticmethod
53-
def add_endpoint_token(endpoint, token):
54-
"""Add a mocked endpoint and its access token."""
73+
def add_device_auth(endpoint, token):
74+
"""Add a mocked endpoint."""
75+
requests_mock.add_callback(
76+
responses.POST,
77+
f"https://{endpoint}/auth/realms/Renku/protocol/openid-connect/auth/device",
78+
callback=device_callback,
79+
)
5580
requests_mock.add_callback(
56-
responses.GET, f"https://{endpoint}/api/auth/cli-token", callback=callback(token)
81+
responses.POST,
82+
f"https://{endpoint}/auth/realms/Renku/protocol/openid-connect/token",
83+
callback=create_token_callback(token),
5784
)
5885

59-
RequestMockWrapper.add_endpoint_token(ENDPOINT, ACCESS_TOKEN)
86+
RequestMockWrapper.add_device_auth(ENDPOINT, ACCESS_TOKEN)
6087

6188
yield RequestMockWrapper

0 commit comments

Comments
 (0)