Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added common/__init__.py
Empty file.
Empty file added common/auth/__init__.py
Empty file.
1 change: 1 addition & 0 deletions common/auth/cognito_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .backend import JSONWebTokenAuthentication # noqa
103 changes: 103 additions & 0 deletions common/auth/cognito_jwt/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from django.apps import apps as django_apps
from django.conf import settings
from django.utils.encoding import force_str
from django.utils.module_loading import import_string
from django.utils.translation import gettext as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authentication import BaseAuthentication

from .validator import TokenError, TokenValidator

# 2 objects expected when parsing Auth Header: 'Bearer' + token
VALID_AUTH_HEADER_LENGTH = 2


def get_authorization_header(request):
"""
Return request's 'X-UHD-AUTH:' header, as a bytestring.

Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get("HTTP_X_UHD_AUTH", b"")
if isinstance(auth, str):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
return auth


class JSONWebTokenAuthentication(BaseAuthentication):
"""Token based authentication using the JSON Web Token standard.
Based on https://github.com/labd/django-cognito-jwt and modified
to suit our use case
"""

def authenticate(self, request):
"""Entrypoint for Django Rest Framework"""
jwt_token = self.get_jwt_token(request)
if jwt_token is None:
return None

# Authenticate token
try:
token_validator = self.get_token_validator(request)
jwt_payload = token_validator.validate(jwt_token)
except TokenError:
raise exceptions.AuthenticationFailed from None

custom_user_manager = self.get_custom_user_manager()
if custom_user_manager:
user = custom_user_manager.get_or_create_for_cognito(jwt_payload)
else:
user_model = self.get_user_model()
user = user_model.objects.get_or_create_for_cognito(jwt_payload)
return (user, jwt_token)

@staticmethod
def get_custom_user_manager():
"""If COGNITO_USER_MANAGER is set, then the user object is obtained
via get_or_create_for_cognito on the user manager, this allows use
of the default unmodified Django User model"""
result = None
custom_user_manager_path = getattr(settings, "COGNITO_USER_MANAGER", False)
if custom_user_manager_path:
result = import_string(custom_user_manager_path)()
return result

@staticmethod
def get_user_model():
user_model = getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL)
return django_apps.get_model(user_model, require_ready=False)

@staticmethod
def get_jwt_token(request):
auth = get_authorization_header(request).split()
if not auth or force_str(auth[0].lower()) != "bearer":
return None

if len(auth) == 1:
msg = _("Invalid Authorization header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
if len(auth) > VALID_AUTH_HEADER_LENGTH:
msg = _(
"Invalid Authorization header. Credentials string "
"should not contain spaces."
)
raise exceptions.AuthenticationFailed(msg)

return auth[1]

@staticmethod
def get_token_validator(request):
return TokenValidator(
settings.COGNITO_AWS_REGION,
settings.COGNITO_USER_POOL,
settings.COGNITO_AUDIENCE,
)

@staticmethod
def authenticate_header(request):
"""
Method required by the DRF in order to return 401 responses for authentication failures, instead of 403.
More details in https://www.django-rest-framework.org/api-guide/authentication/#custom-authentication.
"""
return "Bearer: api"
26 changes: 26 additions & 0 deletions common/auth/cognito_jwt/user_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

from django.contrib.auth import get_user_model
from django.contrib.auth.models import BaseUserManager, User

logger = logging.getLogger(__name__)


class CognitoManager(BaseUserManager):

@staticmethod
def get_or_create_for_cognito(jwt_payload):
username = jwt_payload["entraObjectId"]
try:
user = get_user_model().objects.get(username=username)
logger.debug("Found existing user %s", user.username)
except User.DoesNotExist:
password = None
user = get_user_model().objects.create_user(
username=username,
password=password,
)
logger.info("Created user %s", user.username)
user.is_active = True
user.save()
return user
88 changes: 88 additions & 0 deletions common/auth/cognito_jwt/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import logging

import jwt
import requests
from django.conf import settings
from django.core.cache import cache
from django.utils.functional import cached_property
from jwt.algorithms import RSAAlgorithm

logger = logging.getLogger(__name__)


class TokenError(Exception):
pass


class TokenValidator:
def __init__(self, aws_region, aws_user_pool, audience):
self.aws_region = aws_region
self.aws_user_pool = aws_user_pool
self.audience = audience

@cached_property
def pool_url(self):
return (
f"https://cognito-idp.{self.aws_region}.amazonaws.com/{self.aws_user_pool}"
)

@cached_property
def _json_web_keys(self):
response = requests.get(self.pool_url + "/.well-known/jwks.json", timeout=10)
response.raise_for_status()
json_data = response.json()
return {item["kid"]: json.dumps(item) for item in json_data["keys"]}

def _get_public_key(self, token):
try:
headers = jwt.get_unverified_header(token)
except jwt.DecodeError as exc:
raise TokenError(str(exc)) from exc

if getattr(settings, "COGNITO_PUBLIC_KEYS_CACHING_ENABLED", False):
cache_key = "cognito_jwt:{}".format(headers["kid"])
jwk_data = cache.get(cache_key)

if not jwk_data:
jwk_data = self._json_web_keys.get(headers["kid"])
timeout = getattr(settings, "COGNITO_PUBLIC_KEYS_CACHING_TIMEOUT", 300)
cache.set(cache_key, jwk_data, timeout=timeout)
else:
jwk_data = self._json_web_keys.get(headers["kid"])

if jwk_data:
return RSAAlgorithm.from_jwk(jwk_data)
return None

def validate(self, token):
public_key = self._get_public_key(token)
if not public_key:
msg = "No key found for this token"
raise TokenError(msg)

params = {
"jwt": token,
"key": public_key,
"issuer": self.pool_url,
"algorithms": ["RS256"],
}

logger.debug("JWT - %s", params)
token_payload = jwt.decode(
token, options={"verify_signature": False} # noqa: S5659
)
logger.debug("JWT decoded - %s", token_payload)

if "aud" in token_payload:
params.update({"audience": self.audience})

try:
jwt_data = jwt.decode(**params)
except (
jwt.InvalidTokenError,
jwt.ExpiredSignatureError,
jwt.DecodeError,
) as exc:
raise TokenError(str(exc)) from exc
return jwt_data
4 changes: 4 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
# The name of the AWS profile to use for the AWS client used for ingestion
AWS_PROFILE_NAME = os.environ.get("AWS_PROFILE_NAME")

# Cognito configuration
COGNITO_AWS_REGION = os.environ.get("COGNITO_AWS_REGION")
COGNITO_USER_POOL = os.environ.get("COGNITO_USER_POOL")

# Database configuration
POSTGRES_DB = os.environ.get("POSTGRES_DB")
POSTGRES_USER = os.environ.get("POSTGRES_USER")
Expand Down
7 changes: 7 additions & 0 deletions metrics/api/settings/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,18 @@
},
]

COGNITO_USER_MANAGER = "common.auth.cognito_jwt.user_manager.CognitoManager"
COGNITO_AWS_REGION = config.COGNITO_AWS_REGION
COGNITO_USER_POOL = config.COGNITO_USER_POOL
COGNITO_AUDIENCE = None
COGNITO_PUBLIC_KEYS_CACHING_ENABLED = True
COGNITO_PUBLIC_KEYS_CACHING_TIMEOUT = 60 * 60 * 24 # 24h caching, default is 300s

REST_FRAMEWORK = {
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
"DEFAULT_AUTHENTICATION_CLASSES": [
"rest_framework.authentication.SessionAuthentication",
"common.auth.cognito_jwt.JSONWebTokenAuthentication",
],
}

Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pytest==9.0.2
pytest-cov==7.1.0
pytest-django==4.12.0
pytest-random-order==1.2.0
pytest-responses==0.5.1
ruff==0.15.7
stevedore==5.7.0
django-debug-toolbar==6.2.0
django-debug-toolbar==6.2.0
2 changes: 2 additions & 0 deletions requirements-prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ click==8.3.1
colorama==0.4.6
coreapi==2.3.3
coreschema==0.0.4
cryptography==46.0.5
defusedxml==0.7.1
distlib==0.4.0
django-cors-headers==4.8.0
Expand Down Expand Up @@ -58,6 +59,7 @@ Pillow==12.1.1
platformdirs==4.9.4
plotly==6.6.0
pluggy==1.6.0
pyjwt==2.12.1
pyparsing==3.3.2
pyrsistent==0.20.0
python-dateutil==2.9.0.post0
Expand Down
100 changes: 100 additions & 0 deletions tests/unit/common/auth/cognito_jwt/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import copy

import pytest
from django.conf import settings


@pytest.fixture(autouse=True)
def cognito_settings(settings):
settings.COGNITO_AWS_REGION = "eu-central-1"
settings.COGNITO_USER_POOL = "bla"
settings.COGNITO_AUDIENCE = "my-client-id"
settings.COGNITO_PUBLIC_KEYS_CACHING_ENABLED = False
settings.CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
"LOCATION": "unique-snowflake",
}
}
settings.ROOT_URLCONF = "urls"


def _private_to_public_key(private_key):
data = copy.deepcopy(private_key)
del data["d"]
return data


@pytest.fixture()
def jwk_private_key_one():
return {
"kty": "RSA",
"d": (
"YKOGWFXP3-wWK1OqrKVoTQ5gjkLJPfn2V2ia1tWZ2Ety20W9fpcQmNuS8U"
"bkl86laVergyup8mE0ZpymxXeNRBYI9MrB_k9DCvpnbxW-S3RN8lT1CxZY"
"oUPK8spaO5V5StMfZFesAbwhVIK_flp1NUynM3BkRZ-rRPaDS1Ynz-Z8ag"
"oFAoz3sf946JitajgIyAJUF8wy8j-heXYdOHXeHebBZPvr5bET8hPxapmG"
"gr2_JpKYQbzJ1Emnn1RlTRqdaUWLLKf-XaiemlB2TLNq5YKg-Cr5yIBfro"
"gjhGwh0yGXbuTXzn0QWR3MYoAU9BxHq9vzl-X1ZcF1GqPqOBPigQ"
),
"e": "AQAB",
"use": "sig",
"kid": "key-one",
"alg": "RS256",
"n": (
"iN7iEEFIhcXYFg0ZxvB_etEwN9-ZgA2-g-WzTpcG2qLKjj2rDr80rGPY7I"
"fXaEDppME9ZcN-Mw8oUxSBUIllMNpE9dA0XUhuklFDDiF02FShj2jwua-A"
"k3ORMIgf2ujGPO-b1rkmEKc6TFu_w5jfum9eocaVVIdqYr2j9mG1UCqI0m"
"d-JuGOZi1_f4hp67Qbve_Bzh_3yvQWsTegFNjp55-MzUX-VZ-IEYqhuzaV"
"70t0rnnqFrYgnPqrwo03MOGHUhSJTyg0vBO4S-FoW0e8YKVU1CIOClCuiB"
"qsjkpRBst1DG9094K_PRFcEszIlwt1NUHDMGQV1gHg3zebXxKumQ"
),
}


@pytest.fixture()
def jwk_public_key_one(jwk_private_key_one):
return _private_to_public_key(jwk_private_key_one)


@pytest.fixture()
def jwk_private_key_two():
return {
"kty": "RSA",
"d": (
"G0-8DUpJmbgnYLVCkKTx481skS7DRS4HZlpwHaqzYZn97tVz9sZ_wJmYK1"
"ejaZ_n2K6474zutmx2_XOXNdJJkxdbmi_HwF7V0Ha3R-kPiOUcL0FMI2vC"
"DOjXN8zQG42GYRq1bcrXRBJbSQQK70SiXesv5v1krB0LLr1P8aQTtQw70h"
"xO1avoeeueKhfHET8tIzVlvXz5s4N0s1fH1C-9Z82vTsqyMo51aBqFjPfB"
"Yc0k-AjrrQsVqmvWAXW-7nTiBRdMkZ8Jes1rNnJWYliGmepZbOBQRqEu-I"
"epvAujPdVSsSnQa1zgRKVOgH4KEGVfVtoNY3HoQGaZ5GhiD5BHgQ"
),
"e": "AQAB",
"use": "sig",
"kid": "key-two",
"alg": "RS256",
"n": (
"hvHv4nocfMqZB6e-paozbjr9MaCqOmOtoiiUEwvBPbXgrBH2-MpkzsV_A7"
"OzcMc1R8UMoLE4k4QedFCwM3HwC8CrasH3qkd0GPJA0py1Toa8w7v5TB5e"
"WmGpi_eBjRQcEyq9xVUE637oIfSmgp3U0QOp4px7FpNw8QhP9eMTUnSo_u"
"vsN-dASz4h1U-fBVktT-9yfPBbjq7BER3OjIuVlRAFrptK8xdG1XZtzxdC"
"6O9CGneDwKDcJS-43PGzjyaz4YIRPBPxysZ0veyKxpD-AcC-qAPf0EWdQG"
"6ik-2wNn-5FIHm01MGNcnh6ntuoyZefA3FRjlvuDrwhz2joE6iqw"
),
}


@pytest.fixture()
def jwk_public_key_two(jwk_private_key_two):
return _private_to_public_key(jwk_private_key_two)


@pytest.fixture()
def cognito_well_known_keys(responses, jwk_public_key_one, jwk_public_key_two):
jwk_keys = {"keys": [jwk_public_key_one]}
responses.add(
responses.GET,
"https://cognito-idp.eu-central-1.amazonaws.com/bla/.well-known/jwks.json",
json=jwk_keys,
status=200,
)
Loading
Loading