diff --git a/AUTHORS b/AUTHORS index 68680e4f9..16c2058b8 100644 --- a/AUTHORS +++ b/AUTHORS @@ -29,6 +29,7 @@ Bart Merenda Bas van Oostveen Brian Helba Carl Schwan +Daniel Golding Daniel 'Vector' Kerr Darrel O'Pry Dave Burkholder diff --git a/CHANGELOG.md b/CHANGELOG.md index fab13a0ea..6d2ea4cca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security --> +## [unreleased] + +### Added +* #1273 Add caching of loading of OIDC private key. + ## [2.3.0] 2023-05-31 ### WARNING diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 3779ed491..d22f7ee82 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -19,6 +19,7 @@ from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend from .settings import oauth2_settings +from .utils import jwk_from_pem from .validators import RedirectURIValidator, WildcardSet @@ -234,7 +235,7 @@ def jwk_key(self): if self.algorithm == AbstractApplication.RS256_ALGORITHM: if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: raise ImproperlyConfigured("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm") - return jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + return jwk_from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY) elif self.algorithm == AbstractApplication.HS256_ALGORITHM: return jwk.JWK(kty="oct", k=base64url_encode(self.client_secret)) raise ImproperlyConfigured("This application does not support signed tokens") diff --git a/oauth2_provider/utils.py b/oauth2_provider/utils.py new file mode 100644 index 000000000..de641f74f --- /dev/null +++ b/oauth2_provider/utils.py @@ -0,0 +1,12 @@ +import functools + +from jwcrypto import jwk + + +@functools.lru_cache() +def jwk_from_pem(pem_string): + """ + A cached version of jwcrypto.JWK.from_pem. + Converting from PEM is expensive for large keys such as those using RSA. + """ + return jwk.JWK.from_pem(pem_string.encode("utf-8")) diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py index d7310c58b..e98630f39 100644 --- a/oauth2_provider/views/oidc.py +++ b/oauth2_provider/views/oidc.py @@ -7,7 +7,7 @@ from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from django.views.generic import FormView, View -from jwcrypto import jwk, jwt +from jwcrypto import jwt from jwcrypto.common import JWException from jwcrypto.jws import InvalidJWSObject from jwcrypto.jwt import JWTExpired @@ -30,6 +30,7 @@ get_refresh_token_model, ) from ..settings import oauth2_settings +from ..utils import jwk_from_pem from .mixins import OAuthLibMixin, OIDCLogoutOnlyMixin, OIDCOnlyMixin @@ -114,7 +115,7 @@ def get(self, request, *args, **kwargs): oauth2_settings.OIDC_RSA_PRIVATE_KEY, *oauth2_settings.OIDC_RSA_PRIVATE_KEYS_INACTIVE, ]: - key = jwk.JWK.from_pem(pem.encode("utf8")) + key = jwk_from_pem(pem) data = {"alg": "RS256", "use": "sig", "kid": key.thumbprint()} data.update(json.loads(key.export_public())) keys.append(data) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..2c319b6ea --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,27 @@ +from oauth2_provider import utils + + +def test_jwk_from_pem_caches_jwk(): + a_tiny_rsa_key = """-----BEGIN RSA PRIVATE KEY----- +MGQCAQACEQCxqYaL6GtPooVMhVwcZrCfAgMBAAECECyNmdsuHvMqIEl9/Fex27kC +CQDlc0deuSVrtQIJAMY4MTw2eCeDAgkA5VzfMykQ5yECCQCgkF4Zl0nHPwIJALPv ++IAFUPv3 +-----END RSA PRIVATE KEY-----""" + + # For the same private key we expect the same object to be returned + + jwk1 = utils.jwk_from_pem(a_tiny_rsa_key) + jwk2 = utils.jwk_from_pem(a_tiny_rsa_key) + + assert jwk1 is jwk2 + + a_different_tiny_rsa_key = """-----BEGIN RSA PRIVATE KEY----- +MGMCAQACEQCvyNNNw4J201yzFVogcfgnAgMBAAECEE3oXe5bNlle+xU4EVHTUIEC +CQDpSvwIvDMSIQIJAMDk47DzG9FHAghtvg1TWpy3oQIJAL6NHlS+RBufAgkA6QLA +2GK4aDc= +-----END RSA PRIVATE KEY-----""" + + # But for a different key, a different object + jwk3 = utils.jwk_from_pem(a_different_tiny_rsa_key) + + assert jwk3 is not jwk1