Skip to content

Commit a857078

Browse files
committed
fix: prefer RS256 for JWT validation
1 parent ae3618c commit a857078

File tree

3 files changed

+63
-41
lines changed

3 files changed

+63
-41
lines changed

renku/ui/service/entrypoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from renku.ui.service.logger import service_log
4242
from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET
4343
from renku.ui.service.utils.json_encoder import SvcJSONProvider
44+
from renku.ui.service.utils import jwk_client
4445
from renku.ui.service.views import error_response
4546
from renku.ui.service.views.apispec import apispec_blueprint
4647
from renku.ui.service.views.cache import cache_blueprint
@@ -76,6 +77,8 @@ def create_app(custom_exceptions=True):
7677

7778
app.config["cache"] = cache
7879

80+
app.config["KEYCLOAK_JWK_CLIENT"] = jwk_client()
81+
7982
if not is_test_session_running():
8083
GunicornPrometheusMetrics(app)
8184

renku/ui/service/serializers/headers.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import base64
1818
import binascii
1919
import os
20+
from typing import cast
2021

2122
import jwt
22-
from marshmallow import Schema, ValidationError, fields, post_load, pre_load
23+
from flask import app
24+
from marshmallow import Schema, ValidationError, fields, post_load
2325
from werkzeug.utils import secure_filename
2426

2527
JWT_TOKEN_SECRET = os.getenv("RENKU_JWT_TOKEN_SECRET", "bW9menZ3cnh6cWpkcHVuZ3F5aWJycmJn")
@@ -79,7 +81,7 @@ class RenkuHeaders:
7981

8082
@staticmethod
8183
def decode_token(token):
82-
"""Extract authorization token."""
84+
"""Extract the Gitlab access token form a bearer authorization header value."""
8385
components = token.split(" ")
8486

8587
rfc_compliant = token.lower().startswith("bearer")
@@ -92,45 +94,22 @@ def decode_token(token):
9294

9395
@staticmethod
9496
def decode_user(data):
95-
"""Extract renku user from a JWT."""
96-
decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku")
97+
"""Extract renku user from the Keycloak ID token which is a JWT."""
98+
try:
99+
jwk = cast(jwt.PyJWKClient, app.config["KEYCLOAK_JWK_CLIENT"])
100+
key = jwk.get_signing_key_from_jwt(data)
101+
decoded = jwt.decode(data, key=key, algorithms=["RS256"], audience="renku")
102+
except jwt.PyJWTError:
103+
# NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256
104+
# above fails. We used to need HS256 because a step that is now removed was generating an ID token and
105+
# signing it from data passed in individual header fields.
106+
decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku")
97107
return UserIdentityToken().load(decoded)
98108

99-
@staticmethod
100-
def reset_old_headers(data):
101-
"""Process old version of old headers."""
102-
# TODO: This should be removed once support for them is phased out.
103-
if "renku-user-id" in data:
104-
data.pop("renku-user-id")
105-
106-
if "renku-user-fullname" in data and "renku-user-email" in data:
107-
renku_user = {
108-
"aud": ["renku"],
109-
"name": decode_b64(data.pop("renku-user-fullname")),
110-
"email": decode_b64(data.pop("renku-user-email")),
111-
}
112-
renku_user["sub"] = renku_user["email"]
113-
data["renku-user"] = jwt.encode(renku_user, JWT_TOKEN_SECRET, algorithm="HS256")
114-
115-
return data
116-
117109

118110
class IdentityHeaders(Schema):
119111
"""User identity schema."""
120112

121-
@pre_load
122-
def set_fields(self, data, **kwargs):
123-
"""Set fields for serialization."""
124-
# NOTE: We don't process headers which are not meant for determining identity.
125-
# TODO: Remove old headers support once support for them is phased out.
126-
old_keys = ["renku-user-id", "renku-user-fullname", "renku-user-email"]
127-
expected_keys = old_keys + [field.data_key for field in self.fields.values()]
128-
129-
data = {key.lower(): value for key, value in data.items() if key.lower() in expected_keys}
130-
data = RenkuHeaders.reset_old_headers(data)
131-
132-
return data
133-
134113
@post_load
135114
def set_user(self, data, **kwargs):
136115
"""Extract user object from a JWT."""
@@ -151,12 +130,12 @@ def set_user(self, data, **kwargs):
151130
class RequiredIdentityHeaders(IdentityHeaders):
152131
"""Identity schema for required headers."""
153132

154-
user_token = fields.String(required=True, data_key="renku-user")
155-
auth_token = fields.String(required=True, data_key="authorization")
133+
user_token = fields.String(required=True, data_key="renku-user") # Keycloak ID token
134+
auth_token = fields.String(required=True, data_key="authorization") # Gitlab access token
156135

157136

158137
class OptionalIdentityHeaders(IdentityHeaders):
159138
"""Identity schema for optional headers."""
160139

161-
user_token = fields.String(data_key="renku-user")
162-
auth_token = fields.String(data_key="authorization")
140+
user_token = fields.String(data_key="renku-user") # Keycloak ID token
141+
auth_token = fields.String(data_key="authorization") # Gitlab access token

renku/ui/service/utils/__init__.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
"""Renku service utility functions."""
17-
from typing import Optional, overload
17+
from time import sleep
18+
from typing import Any, Dict, Optional, overload
1819

19-
from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH
20+
import requests
21+
import urllib
22+
from jwt import PyJWKClient
23+
24+
from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL
25+
from renku.ui.service.errors import ProgramInternalError
26+
from renku.ui.service.logger import service_log
27+
from renku.core.util.requests import get
2028

2129

2230
def make_project_path(user, project):
@@ -86,3 +94,35 @@ def normalize_git_url(git_url: Optional[str]) -> Optional[str]:
8694
git_url = git_url[: -len(".git")]
8795

8896
return git_url
97+
98+
99+
def oidc_discovery() -> Dict[str, Any]:
100+
"""Query the OIDC discovery endpoint from Keycloak with retries, parse the result with JSON and it."""
101+
retries = 0
102+
max_retries = 30
103+
sleep_seconds = 2
104+
while True:
105+
retries += 1
106+
try:
107+
res: requests.Response = get(OIDC_URL)
108+
except (requests.exceptions.HTTPError, urllib.error.HTTPError) as e:
109+
if not retries < max_retries:
110+
service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.")
111+
raise e
112+
service_log.info(
113+
f"Failed to get OIDC discovery data from {OIDC_URL}, sleeping for {sleep_seconds} seconds and retrying"
114+
)
115+
sleep(sleep_seconds)
116+
else:
117+
service_log.info(f"Successfully fetched OIDC discovery data from {OIDC_URL}")
118+
return res.json()
119+
120+
121+
def jwk_client() -> PyJWKClient:
122+
"""Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation"""
123+
oidc_data = oidc_discovery()
124+
jwks_uri = oidc_data.get("jwks_uri")
125+
if not jwks_uri:
126+
raise ProgramInternalError(error_message="Could not find JWK URI in the OIDC discovery data")
127+
jwk = PyJWKClient(jwks_uri)
128+
return jwk

0 commit comments

Comments
 (0)