Skip to content

Commit 932d14f

Browse files
committed
Add service token authentication mechanism
1 parent b61c652 commit 932d14f

File tree

5 files changed

+110
-11
lines changed

5 files changed

+110
-11
lines changed

openshift/template.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ objects:
9595
configMapKeyRef:
9696
name: bayesian-config
9797
key: keycloak-url
98+
- name: BAYESIAN_AUTH_PUBLIC_KEYS_URL
99+
valueFrom:
100+
configMapKeyRef:
101+
name: bayesian-config
102+
key: auth-url
98103
- name: BAYESIAN_JWT_AUDIENCE
99104
value: "fabric8-online-platform,openshiftio-public"
100105
image: "${DOCKER_REGISTRY}/${DOCKER_IMAGE}:${IMAGE_TAG}"

src/auth.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import jwt
66
from os import getenv
77

8-
98
from exceptions import HTTPError
10-
from utils import fetch_public_key
9+
from utils import fetch_public_key, fetch_service_public_keys
1110

1211

13-
def decode_token(token):
12+
def decode_user_token(token):
1413
"""Decode the authorization token read from the request header."""
1514
if token is None:
1615
return {}
@@ -38,6 +37,39 @@ def decode_token(token):
3837
return decoded_token
3938

4039

40+
def decode_service_token(token): # pragma: no cover
41+
"""Decode OSIO service token."""
42+
# TODO: Merge this function and user token function once audience is removed from user tokens.
43+
if token is None:
44+
return {}
45+
46+
if token.startswith('Bearer '):
47+
_, token = token.split(' ', 1)
48+
49+
pub_keys = fetch_service_public_keys(current_app)
50+
decoded_token = None
51+
52+
# Since we have multiple public keys, we need to verify against every public key.
53+
# Token can be decoded by any one of the available public keys.
54+
for pub_key in pub_keys:
55+
try:
56+
pub_key = '-----BEGIN PUBLIC KEY-----\n{pkey}\n-----END PUBLIC KEY-----'\
57+
.format(pkey=pub_key)
58+
decoded_token = jwt.decode(token, pub_key, algorithms=['RS256'])
59+
except jwt.InvalidTokenError:
60+
current_app.logger.error("Auth token couldn't be decoded for public key: {}"
61+
.format(pub_key))
62+
decoded_token = None
63+
64+
if decoded_token:
65+
break
66+
67+
if not decoded_token:
68+
raise jwt.InvalidTokenError('Auth token cannot be verified.')
69+
70+
return decoded_token
71+
72+
4173
def get_token_from_auth_header():
4274
"""Get the authorization token read from the request header."""
4375
return request.headers.get('Authorization')
@@ -62,7 +94,37 @@ def wrapper(*args, **kwargs):
6294
lgr = current_app.logger
6395

6496
try:
65-
decoded = decode_token(get_token_from_auth_header())
97+
decoded = decode_user_token(get_token_from_auth_header())
98+
if not decoded:
99+
lgr.exception('Provide an Authorization token with the API request')
100+
raise HTTPError(401, 'Authentication failed - token missing')
101+
102+
lgr.info('Successfuly authenticated user {e} using JWT'.
103+
format(e=decoded.get('email')))
104+
except jwt.ExpiredSignatureError as exc:
105+
lgr.exception('Expired JWT token')
106+
raise HTTPError(401, 'Authentication failed - token has expired') from exc
107+
except Exception as exc:
108+
lgr.exception('Failed decoding JWT token')
109+
raise HTTPError(401, 'Authentication failed - could not decode JWT token') from exc
110+
111+
return view(*args, **kwargs)
112+
113+
return wrapper
114+
115+
116+
def service_token_required(view): # pragma: no cover
117+
"""Check if the request contains a valid service token."""
118+
@wraps(view)
119+
def wrapper(*args, **kwargs):
120+
# Disable authentication for local setup
121+
if getenv('DISABLE_AUTHENTICATION') in ('1', 'True', 'true'):
122+
return view(*args, **kwargs)
123+
124+
lgr = current_app.logger
125+
126+
try:
127+
decoded = decode_service_token(get_token_from_auth_header())
66128
if not decoded:
67129
lgr.exception('Provide an Authorization token with the API request')
68130
raise HTTPError(401, 'Authentication failed - token missing')

src/rest_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from flask_cors import CORS
55
from utils import DatabaseIngestion, scan_repo, validate_request_data, retrieve_worker_result
66
from f8a_worker.setup_celery import init_selinon
7-
from auth import login_required
7+
from auth import login_required, service_token_required
88
from exceptions import HTTPError
99

1010
app = Flask(__name__)
@@ -215,5 +215,12 @@ def handle_error(e): # pragma: no cover
215215
}), e.status_code
216216

217217

218+
@app.route('/test-service-token')
219+
@service_token_required
220+
def test_service_token(): # pragma: no cover
221+
"""Test the service token authentication mechanism."""
222+
return flask.jsonify({'token': 'is_valid'}), 200
223+
224+
218225
if __name__ == "__main__": # pragma: no cover
219226
app.run()

src/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,28 @@ def fetch_public_key(app):
292292
app.public_key = None
293293

294294
return app.public_key
295+
296+
297+
def fetch_service_public_keys(app): # pragma: no cover
298+
"""Get public keys for OSIO service account. Currently, there are three public keys."""
299+
if not getattr(app, "service_public_keys", []):
300+
auth_url = os.getenv('BAYESIAN_AUTH_PUBLIC_KEYS_URL', '')
301+
if auth_url:
302+
try:
303+
auth_url = auth_url.strip('/') + '/api/token/keys?format=pem'
304+
result = requests.get(auth_url, timeout=0.5)
305+
app.logger.info('Fetching public key from %s, status %d, result: %s',
306+
auth_url, result.status_code, result.text)
307+
except requests.exceptions.Timeout:
308+
app.logger.error('Timeout fetching public key from %s', auth_url)
309+
return ''
310+
if result.status_code != 200:
311+
return ''
312+
313+
keys = result.json().get('keys', [])
314+
app.service_public_keys = keys
315+
316+
else:
317+
app.service_public_keys = None
318+
319+
return app.service_public_keys

tests/test_auth.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,39 +68,39 @@ def mocked_get_audiences_3():
6868
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
6969
def test_decode_token_invalid_input_1(mocked_fetch_public_key, mocked_get_audiences):
7070
"""Test the invalid input handling during token decoding."""
71-
assert decode_token(None) == {}
71+
assert decode_user_token(None) == {}
7272

7373

7474
@patch("auth.get_audiences", side_effect=mocked_get_audiences)
7575
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
7676
def test_decode_token_invalid_input_2(mocked_fetch_public_key, mocked_get_audiences):
7777
"""Test the invalid input handling during token decoding."""
7878
with pytest.raises(Exception):
79-
assert decode_token("Foobar") is None
79+
assert decode_user_token("Foobar") is None
8080

8181

8282
@patch("auth.get_audiences", side_effect=mocked_get_audiences)
8383
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
8484
def test_decode_token_invalid_input_3(mocked_fetch_public_key, mocked_get_audiences):
8585
"""Test the invalid input handling during token decoding."""
8686
with pytest.raises(Exception):
87-
assert decode_token("Bearer ") is None
87+
assert decode_user_token("Bearer ") is None
8888

8989

9090
@patch("auth.get_audiences", side_effect=mocked_get_audiences)
9191
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
9292
def test_decode_token_invalid_input_4(mocked_fetch_public_key, mocked_get_audiences):
9393
"""Test the invalid input handling during token decoding."""
9494
with pytest.raises(Exception):
95-
assert decode_token("Bearer ") is None
95+
assert decode_user_token("Bearer ") is None
9696

9797

9898
@patch("auth.get_audiences", side_effect=mocked_get_audiences_2)
9999
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
100100
def test_decode_token_invalid_input_5(mocked_fetch_public_key, mocked_get_audiences):
101101
"""Test the handling wrong JWT tokens."""
102102
with pytest.raises(Exception):
103-
assert decode_token("Bearer something") is None
103+
assert decode_user_token("Bearer something") is None
104104

105105

106106
@patch("auth.get_audiences", side_effect=mocked_get_audiences_3)
@@ -112,7 +112,7 @@ def test_decode_token_invalid_input_6(mocked_fetch_public_key, mocked_get_audien
112112
'aud': 'foo:bar'
113113
}
114114
token = jwt.encode(payload, PRIVATE_KEY, algorithm='RS256').decode("utf-8")
115-
assert decode_token(token) is not None
115+
assert decode_user_token(token) is not None
116116

117117

118118
def test_audiences():

0 commit comments

Comments
 (0)