diff --git a/CHANGELOG.md b/CHANGELOG.md index 685a564c28..e7b377417a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ All notable changes to this project will be documented in this file. From versio - From now on PostgREST will follow a `MAJOR.PATCH` two-part versioning. Only even-numbered MAJOR versions will be released, reserving odd-numbered MAJOR versions for development. - Replaced `jwt-cache-max-lifetime` config with `jwt-cache-max-entries` by @mkleczek in #4084 - `log-query` config now takes a boolean instead of a string value by @steve-chavez in #3934 +- `jwt-aud` config now takes a regular expression to match against `aud` claim #2099 ## [13.0.8] - 2025-10-24 diff --git a/docs/references/auth.rst b/docs/references/auth.rst index b92bd8c210..81625e214e 100644 --- a/docs/references/auth.rst +++ b/docs/references/auth.rst @@ -193,13 +193,18 @@ PostgREST has built-in validation of the `JWT audience claim ` error. + If the ``aud`` key **is not present** or if its value is ``null`` or ``[]``, PostgREST will interpret this token as allowed for all audiences and will complete the request. +Examples: +- To make PostgREST accept ``aud`` claim value from a set ``audience1``, ``audience2``, ``otheraudience``, :ref:`jwt-aud` claim should be set to ``audience1|audience2|otheraudience``. +- To make PostgREST accept ``aud`` claim value matching any ``https`` URI pointing to a host in ``example.com`` domain, :ref:`jwt-aud` claim should be set to ``https://[a-zA-Z0-9_]*\.example\.com``. +- To make PostgREST accept any ``aud`` claim value , :ref:`jwt-aud` claim should be set to ``.*`` (which is the default). + .. _jwt_caching: JWT Cache diff --git a/docs/references/configuration.rst b/docs/references/configuration.rst index c3b563f1f2..11892904c3 100644 --- a/docs/references/configuration.rst +++ b/docs/references/configuration.rst @@ -596,14 +596,14 @@ jwt-aud ------- =============== ================================= - **Type** String - **Default** `n/a` + **Type** String (must be a valid regular expression) + **Default** `.*` **Reloadable** Y **Environment** PGRST_JWT_AUD **In-Database** pgrst.jwt_aud =============== ================================= - Specifies an audience for the JWT ``aud`` claim. See :ref:`jwt_aud`. + Specifies a regular expression to match against the JWT ``aud`` claim. See :ref:`jwt_aud`. .. _jwt-role-claim-key: diff --git a/nix/tools/generate_targets.py b/nix/tools/generate_targets.py index 762a2c21ab..e7b637e964 100644 --- a/nix/tools/generate_targets.py +++ b/nix/tools/generate_targets.py @@ -33,6 +33,7 @@ def generate_jwt(now: int, exp_inc: Optional[int], is_hs: bool) -> str: payload = { "sub": f"user_{random.getrandbits(32)}", "iat": now, + "aud": "veryveryveryveryverylonglonglonglonglongaudience", "role": "postgrest_test_author", } diff --git a/nix/tools/loadtest.nix b/nix/tools/loadtest.nix index 7ff786efba..6181d1403f 100644 --- a/nix/tools/loadtest.nix +++ b/nix/tools/loadtest.nix @@ -63,6 +63,7 @@ let # load test works across branches # TODO clean once PGRST_JWT_CACHE_MAX_ENTRIES merged and released export PGRST_JWT_CACHE_MAX_LIFETIME="86400" + export PGRST_JWT_AUD="audience|([a-z]er[a-z])*(long)*aud..nce" mkdir -p "$(dirname "$_arg_output")" abs_output="$(realpath "$_arg_output")" diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index 7b5b9bd241..134a79d7c7 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RecordWildCards #-} {-| Module : PostgREST.Auth Description : PostgREST authentication functions. @@ -30,13 +32,19 @@ import System.TimeIt (timeItT) import PostgREST.AppState (AppState, getConfig, getJwtCacheState, getTime) -import PostgREST.Auth.Jwt (parseClaims) import PostgREST.Auth.JwtCache (lookupJwtCache) import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Config (AppConfig (..)) -import PostgREST.Error (Error (..)) +import PostgREST.Config (AppConfig (..), FilterExp (..), + JSPath, JSPathExp (..)) +import PostgREST.Error (Error (..), JwtError (..)) -import Protolude +import Control.Monad.Except (liftEither) +import qualified Data.Aeson as JSON +import qualified Data.Aeson.Key as K +import qualified Data.Aeson.KeyMap as KM +import qualified Data.Text as T +import qualified Data.Vector as V +import Protolude -- | Validate authorization header -- Parse and store JWT claims for future use in the request. @@ -46,7 +54,8 @@ middleware appState app req respond = do time <- getTime appState let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) - parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time + parseToken = maybe (pure KM.empty) (lookupJwtCache jwtCacheState time) + parseJwt = runExceptT $ parseToken >=> parseClaims conf $ token jwtCacheState = getJwtCacheState appState -- If ServerTimingEnabled -> calculate JWT validation time @@ -59,6 +68,38 @@ middleware appState app req respond = do app req' respond +parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> JSON.Object -> m AuthResult +parseClaims AppConfig{configJwtRoleClaimKey, configDbAnonRole} mclaims = do + -- role defaults to anon if not specified in jwt + role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $ + unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole + pure AuthResult + { authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role) + , authRole = role + } + where + walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value + walkJSPath x [] = x + walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest + walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest + walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar + walkJSPath _ _ = Nothing + + findFirstMatch matchWith pattern = foldr checkMatch Nothing + where + checkMatch (JSON.String txt) acc + | pattern `matchWith` txt = Just $ JSON.String txt + | otherwise = acc + checkMatch _ acc = acc + + unquoted :: JSON.Value -> BS.ByteString + unquoted (JSON.String t) = encodeUtf8 t + unquoted v = BS.toStrict $ JSON.encode v + authResultKey :: Vault.Key (Either Error AuthResult) authResultKey = unsafePerformIO Vault.newKey {-# NOINLINE authResultKey #-} diff --git a/src/PostgREST/Auth/Jwt.hs b/src/PostgREST/Auth/Jwt.hs index 734f302ab0..3b83cf6676 100644 --- a/src/PostgREST/Auth/Jwt.hs +++ b/src/PostgREST/Auth/Jwt.hs @@ -4,28 +4,30 @@ Description : PostgREST JWT support functions. This module provides functions to deal with JWT parsing and validation (http://jwt.io). -} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module PostgREST.Auth.Jwt - ( parseAndDecodeClaims - , parseClaims) where - -import qualified Data.Aeson as JSON -import qualified Data.Aeson.Key as K -import qualified Data.Aeson.KeyMap as KM -import qualified Data.ByteString as BS -import qualified Data.ByteString.Internal as BS -import qualified Data.ByteString.Lazy.Char8 as LBS -import qualified Data.Scientific as Sci -import qualified Data.Text as T -import qualified Data.Vector as V -import qualified Jose.Jwk as JWT -import qualified Jose.Jwt as JWT + ( Validation (..) + , Validated (getValidated) + , parseAndDecodeClaims + , validateAud + , validateTimeClaims + , (>>>)) where + +import qualified Data.Aeson as JSON +import qualified Data.Aeson.KeyMap as KM +import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS +import qualified Data.Scientific as Sci +import qualified Jose.Jwk as JWT +import qualified Jose.Jwt as JWT import Control.Monad.Except (liftEither) import Data.Either.Combinators (mapLeft) @@ -33,39 +35,77 @@ import Data.Text () import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath, - JSPathExp (..), audMatchesCfg) -import PostgREST.Error (Error (..), - JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed), - JwtDecodeError (..), JwtError (..)) +import PostgREST.Error (Error (..), + JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed), + JwtDecodeError (..), JwtError (..)) import Data.Aeson ((.:?)) import Data.Aeson.Types (parseMaybe) +import Data.Coerce (coerce) import Jose.Jwk (JwkSet) import Protolude hiding (first) +-- A value tagged by a type-level list of validations pefrormed on it +newtype Validated (k :: [v]) a = Validated { getValidated :: a } + +-- Helper to implement type safe validation chaining +type family (++) (lst::[k]) lst' where + '[] ++ lst = lst + (l : ls) ++ lst = l : (ls ++ lst) + +-- Validation chaining operator +(>>>) :: (Monad m, Coercible (m (Validated kc c)) (m (Validated (kb ++ kc) c))) + => (a -> m (Validated kb b)) + -> (b -> m (Validated kc c)) + -> a + -> m (Validated (kb ++ kc) c) +f >>> g = coerce . (f >=> g . coerce) + parseAndDecodeClaims :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JSON.Object -parseAndDecodeClaims jwkSet token = parseToken jwkSet token >>= decodeClaims +parseAndDecodeClaims jwkSet = parseToken jwkSet >=> decodeClaims + +data Validation = Aud | Time + +validateAud :: MonadError Error m => (Text -> Bool) -> JSON.Object -> m (Validated '[Aud] JSON.Object) +validateAud = validate . checkAud + +validateTimeClaims :: MonadError Error m => UTCTime -> JSON.Object -> m (Validated '[Time] JSON.Object) +validateTimeClaims = validate . checkExpNbfIat decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims) decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType -validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m () -validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims) +validate :: MonadError Error m => (t -> Alt Maybe JwtClaimsError) -> t -> m (Validated k t) +validate f claims = fmap Validated $ liftEither $ maybeToLeft claims $ fmap JwtErr . getAlt $ JwtClaimsErr <$> f claims data ValidAud = VAString Text | VAArray [Text] deriving Generic instance JSON.FromJSON ValidAud where parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } -checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError -checkForErrors time audMatches = mconcat +claim :: (JSON.FromJSON a, Applicative f, Monoid (f p)) => KM.Key -> p -> (a -> f p) -> JSON.Object -> f p +claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key) + +checkValue :: (Applicative f, Monoid (f p)) => (t -> Bool) -> p -> t -> f p +checkValue invalid msg val = + if invalid val then + pure msg + else + mempty + +checkAud :: (Applicative f, Monoid (f JwtClaimsError)) => (Text -> Bool) -> JSON.Object -> f JwtClaimsError +checkAud audMatches = claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience + where + validAud = \case + (VAString aud) -> audMatches aud + (VAArray auds) -> null auds || any audMatches auds + +checkExpNbfIat :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> JSON.Object -> m JwtClaimsError +checkExpNbfIat time = mconcat [ claim "exp" ExpClaimNotNumber $ inThePast JWTExpired , claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid , claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture - , claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience ] where allowedSkewSeconds = 30 :: Int64 @@ -78,20 +118,6 @@ checkForErrors time audMatches = mconcat checkTime cond = checkValue (cond. sciToInt) - validAud = \case - (VAString aud) -> audMatches aud - (VAArray auds) -> null auds || any audMatches auds - - checkValue invalid msg val = - if invalid val then - pure msg - else - mempty - - claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key) - --- | Receives the JWT secret and audience (from config) and a JWT and returns a --- JSON object of JWT claims. parseToken :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JWT.JwtContent parseToken _ "" = throwError $ JwtErr $ JwtDecodeErr EmptyAuthHeader parseToken secret tkn = do @@ -116,36 +142,3 @@ parseToken secret tkn = do jwtDecodeError JWT.BadCrypto = JwtDecodeErr BadCrypto -- Control never reaches here, the decode function only returns the above three jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError - -parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult -parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do - validateClaims time (audMatchesCfg cfg) mclaims - -- role defaults to anon if not specified in jwt - role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $ - unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole - pure AuthResult - { authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role) - , authRole = role - } - where - walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value - walkJSPath x [] = x - walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest - walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar - walkJSPath _ _ = Nothing - - findFirstMatch matchWith pattern = foldr checkMatch Nothing - where - checkMatch (JSON.String txt) acc - | pattern `matchWith` txt = Just $ JSON.String txt - | otherwise = acc - checkMatch _ acc = acc - - unquoted :: JSON.Value -> BS.ByteString - unquoted (JSON.String t) = encodeUtf8 t - unquoted v = LBS.toStrict $ JSON.encode v diff --git a/src/PostgREST/Auth/JwtCache.hs b/src/PostgREST/Auth/JwtCache.hs index 395ecd48d4..86234b1a68 100644 --- a/src/PostgREST/Auth/JwtCache.hs +++ b/src/PostgREST/Auth/JwtCache.hs @@ -1,15 +1,20 @@ {-| -Module : PostgREST.Auth.JwtCache +Module : PostgREST.Auth.Caching Description : PostgREST JWT validation results Cache. This module provides functions to deal with the JWT cache. -} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} module PostgREST.Auth.JwtCache ( init @@ -18,97 +23,116 @@ module PostgREST.Auth.JwtCache , lookupJwtCache ) where -import qualified Data.Aeson as JSON -import qualified Data.Aeson.KeyMap as KM +import qualified Data.Aeson as JSON import PostgREST.Error (Error (..), JwtError (JwtSecretMissing)) import Control.Concurrent.STM (newTVarIO, readTVar, writeTVar) import Control.Concurrent.STM.TVar (TVar) -import Control.Monad.Error.Class (liftEither) import Data.ByteString hiding (all, init) import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import Data.Time.Clock (UTCTime) import Jose.Jwk (JwkSet) -import PostgREST.Auth.Jwt (parseAndDecodeClaims) +import PostgREST.Auth.Jwt (Validated (getValidated), + Validation (..), + parseAndDecodeClaims, + validateAud, + validateTimeClaims, + (>>>)) import PostgREST.Cache.Sieve (alwaysValid) import qualified PostgREST.Cache.Sieve as SC -import PostgREST.Config (AppConfig (..)) +import PostgREST.Config (AppConfig (..), + audMatchesCfg) import PostgREST.Observation (Observation (JwtCacheEviction, JwtCacheLookup), ObservationHandler) import Protolude -data JwtCacheState = JwtCacheState ObservationHandler (IORef JwtCache) +type Cache m v = SC.Cache m ByteString v -class CacheVariant m v where - cached :: SC.Cache m ByteString v -> ByteString -> ExceptT Error IO JSON.Object +type SelectedCacheVariant = Cache (ExceptT Error IO) (Validated '[Aud] JSON.Object) -{-| -Jwt caching can have three different configurations: -* missing JWT Key (no caching and throw error when JWT token present in the request) -* JWT cache turned off -* JWT cache turned on +data CacheState = + JwksNotConfigured | + NotCaching JwkSet (Text -> Bool) | + Caching JwkSet AppConfig (TVar Int) SelectedCacheVariant -All three options are represented by JwtCache data type. +data JwtCacheState = JwtCacheState ObservationHandler (IORef CacheState) -Handling of reconfiguration is centralized in this module. --} -data JwtCache = - JwtNoJwks | - JwtNoCache JwkSet | - forall m v. CacheVariant m v => JwtCache JwkSet (TVar Int) (SC.Cache m ByteString v) +parseAndValidateAud :: CacheState -> ByteString -> ExceptT Error IO (Validated '[Aud] JSON.Object) +parseAndValidateAud (Caching _ config _ c) = lookup config c +parseAndValidateAud (NotCaching key audMatches) = parseAndDecodeClaims key >=> validateAud audMatches +parseAndValidateAud JwksNotConfigured = const $ throwError (JwtErr JwtSecretMissing) + +class NeedsReinitialize v where + needsReinitialize :: AppConfig -> AppConfig -> Bool + +instance NeedsReinitialize JSON.Object where + needsReinitialize _ _ = False -instance CacheVariant IO (Either Error JSON.Object) where - cached c = lift . SC.cached c >=> liftEither +instance NeedsReinitialize (Validated (Aud : rest) JSON.Object) where + needsReinitialize old new = configJwtAudience old /= configJwtAudience new -instance CacheVariant (ExceptT Error IO) JSON.Object where - cached = SC.cached +instance NeedsReinitialize v => NeedsReinitialize (SC.Cache m k v) where + needsReinitialize = needsReinitialize @v -decode :: JwtCache -> ByteString -> ExceptT Error IO JSON.Object -decode JwtNoJwks = const $ throwError (JwtErr JwtSecretMissing) -decode (JwtNoCache key) = parseAndDecodeClaims key -decode (JwtCache _ _ c) = cached c +class CacheVariant c where + newCache :: STM Int -> JwkSet -> ObservationHandler -> AppConfig -> IO c + lookup :: AppConfig -> c -> ByteString -> ExceptT Error IO (Validated '[Aud] JSON.Object) + +-- Cache parsed JWTs with valid signature and valid aud +instance CacheVariant (Cache (ExceptT Error IO) (Validated '[Aud] JSON.Object)) where + newCache maxSize key observationHandler config = SC.cacheIO (SC.CacheConfig maxSize + (parseAndDecodeClaims key >=> validateAud (audMatchesCfg config)) + (lift . observationHandler . JwtCacheLookup) -- lookup metrics + (const . const $ lift $ observationHandler JwtCacheEviction) -- evictions metrics + alwaysValid) -- no invalidation for now + lookup _ = SC.cached -- | Reconfigure JWT caching and update JwtCacheState accordingly update :: JwtCacheState -> AppConfig -> IO () -update (JwtCacheState observationHandler jwtCacheState) config@AppConfig{configJWKS, configJwtCacheMaxEntries} = - let reinitialize = - newJwtCache config observationHandler - >>= writeIORef jwtCacheState - in - readIORef jwtCacheState >>= \case - (JwtCache decodingKey maxSize _) -> - if configJWKS /= Just decodingKey || configJwtCacheMaxEntries <= 0 then - -- reinitialize if key changed or cache disabled +update (JwtCacheState observationHandler ref) config@AppConfig{configJWKS, configJwtCacheMaxEntries} = + readIORef ref >>= \case + (Caching decodingKey oldConfig maxSize cache) -> + if configJWKS /= Just decodingKey || + configJwtCacheMaxEntries <= 0 || + needsReinitialize @SelectedCacheVariant oldConfig config + then + -- reinitialize if key changed or cache disabled or the cache requires reinitialization reinitialize - else - -- max size changed - set it and let the cache shrink itself if necessary + else do + -- max size changed - set it and let the cache resize itself if necessary atomically $ writeTVar maxSize configJwtCacheMaxEntries - + -- save new config for future updates + writeIORef ref $ Caching decodingKey config maxSize cache _ -> reinitialize + where + reinitialize = newJwtCacheState config observationHandler >>= writeIORef ref init :: AppConfig -> ObservationHandler -> IO JwtCacheState -init config = fmap (<$>) JwtCacheState <*> (newJwtCache config >=> newIORef) +init config = fmap (<$>) JwtCacheState <*> (newJwtCacheState config >=> newIORef) -- | Initialize JwtCacheState -newJwtCache :: AppConfig -> ObservationHandler -> IO JwtCache -newJwtCache AppConfig{configJWKS, configJwtCacheMaxEntries} observationHandler = do - maybe (pure JwtNoJwks) initCache configJWKS +newJwtCacheState :: AppConfig -> ObservationHandler -> IO CacheState +newJwtCacheState config@AppConfig{configJWKS, configJwtCacheMaxEntries} observationHandler = + maybe (pure JwksNotConfigured) initCache configJWKS where - initCache key = if configJwtCacheMaxEntries <= 0 then pure (JwtNoCache key) else createCache key configJwtCacheMaxEntries + initCache key = + if configJwtCacheMaxEntries <= 0 then + pure $ NotCaching key (audMatchesCfg config) + else + createCache key configJwtCacheMaxEntries createCache key maxSize = do - maxSizeTVar <- newTVarIO maxSize - JwtCache key maxSizeTVar <$> - notCachingErrors (readTVar maxSizeTVar) key + maxSizeTVar <- newTVarIO maxSize + Caching key config maxSizeTVar <$> + newCache + (readTVar maxSizeTVar) key observationHandler config - notCachingErrors :: STM Int -> JwkSet -> IO (SC.Cache (ExceptT Error IO) ByteString JSON.Object) - notCachingErrors maxSize key = SC.cacheIO (SC.CacheConfig maxSize - (parseAndDecodeClaims key) - (lift . observationHandler . JwtCacheLookup) -- lookup metrics - (const . const $ lift $ observationHandler JwtCacheEviction) -- evictions metrics - alwaysValid) -- no invalidation for now +parseAndValidateClaims :: UTCTime -> ByteString -> CacheState -> ExceptT Error IO (Validated [Aud, Time] JSON.Object) +parseAndValidateClaims time k c = (parseAndValidateAud c >>> validateTimeClaims time) k -lookupJwtCache :: JwtCacheState -> Maybe ByteString -> ExceptT Error IO JSON.Object -lookupJwtCache (JwtCacheState _ cacheState) k = liftIO (readIORef cacheState) >>= flip (maybe (pure KM.empty)) k . decode +lookupJwtCache :: JwtCacheState -> UTCTime -> ByteString -> ExceptT Error IO JSON.Object +lookupJwtCache (JwtCacheState _ cacheState) time k = + liftIO (readIORef cacheState) >>= fmap getValidated . parseAndValidateClaims time k diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index 2bfcb9e93b..38cfe66847 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -19,6 +19,7 @@ module PostgREST.Config , LogLevel(..) , OpenAPIMode(..) , Proxy(..) + , CfgAud , toText , isMalformedProxyUri , readAppConfig @@ -29,6 +30,8 @@ module PostgREST.Config , addTargetSessionAttrs , exampleConfigFile , audMatchesCfg + , defaultCfgAud + , parseCfgAud ) where import qualified Data.Aeson as JSON @@ -50,7 +53,7 @@ import Data.List.NonEmpty (fromList, toList) import Data.Maybe (fromJust) import Data.Scientific (floatingOrInteger) import Jose.Jwk (Jwk, JwkSet) -import Network.URI (escapeURIString, isURI, +import Network.URI (escapeURIString, isUnescapedInURIComponent) import Numeric (readOct, showOct) import System.Environment (getEnvironment) @@ -66,10 +69,31 @@ import PostgREST.Config.Proxy (Proxy (..), import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier, dumpQi, toQi) -import Protolude hiding (Proxy, toList) +import Protolude hiding (Proxy, toList) +import qualified Text.Regex.TDFA as R + +data ParsedValue a b = ParsedValue { + sourceValue :: a, + parsedValue :: b +} +instance Eq a => Eq (ParsedValue a b) where + x == y = sourceValue x == sourceValue y + +newtype CfgAud = CfgAud { unCfgAud :: ParsedValue (Maybe Text) R.Regex } deriving Eq + +parseCfgAud :: MonadFail m => Text -> m CfgAud +parseCfgAud = fmap CfgAud . (fmap . ParsedValue . Just <*> parseRegex) + where + parseRegex = maybe (fail "jwt-aud should be a valid regular expression") pure . R.makeRegexM . bounded + -- need start and end of text bounds so that + -- regex does not match parts of text + bounded = ("\\`(" <>) . (<> "\\')") + +defaultCfgAud :: CfgAud +defaultCfgAud = CfgAud $ ParsedValue Nothing $ R.makeRegex (".*"::Text) audMatchesCfg :: AppConfig -> Text -> Bool -audMatchesCfg = maybe (const True) (==) . configJwtAudience +audMatchesCfg = R.matchTest . parsedValue . unCfgAud . configJwtAudience data AppConfig = AppConfig { configAppSettings :: [(Text, Text)] @@ -97,7 +121,7 @@ data AppConfig = AppConfig , configDbUri :: Text , configFilePath :: Maybe FilePath , configJWKS :: Maybe JwkSet - , configJwtAudience :: Maybe Text + , configJwtAudience :: CfgAud , configJwtRoleClaimKey :: JSPath , configJwtSecret :: Maybe BS.ByteString , configJwtSecretIsBase64 :: Bool @@ -171,7 +195,7 @@ toText conf = ,("db-pre-config", q . maybe mempty dumpQi . configDbPreConfig) ,("db-tx-end", q . showTxEnd) ,("db-uri", q . configDbUri) - ,("jwt-aud", q . fromMaybe mempty . configJwtAudience) + ,("jwt-aud", q . fold . sourceValue . unCfgAud . configJwtAudience) ,("jwt-role-claim-key", q . T.intercalate mempty . fmap dumpJSPath . configJwtRoleClaimKey) ,("jwt-secret", q . T.decodeUtf8 . showJwtSecret) ,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64) @@ -279,7 +303,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = <*> (fromMaybe "postgresql://" <$> optString "db-uri") <*> pure optPath <*> pure Nothing - <*> optStringOrURI "jwt-aud" + <*> (optStringEmptyable "jwt-aud" >>= maybe (pure defaultCfgAud) parseCfgAud) <*> parseRoleClaimKey "jwt-role-claim-key" "role-claim-key" <*> (fmap encodeUtf8 <$> optString "jwt-secret") <*> (fromMaybe False <$> optWithAlias @@ -399,20 +423,6 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = optStringEmptyable :: C.Key -> C.Parser C.Config (Maybe Text) optStringEmptyable k = overrideFromDbOrEnvironment C.optional k coerceText - optStringOrURI :: C.Key -> C.Parser C.Config (Maybe Text) - optStringOrURI k = do - stringOrURI <- mfilter (/= "") <$> overrideFromDbOrEnvironment C.optional k coerceText - -- If the string contains ':' then it should - -- be a valid URI according to RFC 3986 - case stringOrURI of - Just s -> if T.isInfixOf ":" s then validateURI s else return (Just s) - Nothing -> return Nothing - where - validateURI :: Text -> C.Parser C.Config (Maybe Text) - validateURI s = if isURI (T.unpack s) - then return $ Just s - else fail "jwt-aud should be a string or a valid URI" - optInt :: (Read i, Integral i) => C.Key -> C.Parser C.Config (Maybe i) optInt k = join <$> overrideFromDbOrEnvironment C.optional k coerceInt diff --git a/test/io/fixtures.yaml b/test/io/fixtures.yaml index ee00182cc0..68892f5dfb 100644 --- a/test/io/fixtures.yaml +++ b/test/io/fixtures.yaml @@ -45,7 +45,7 @@ cli: expect: error use_defaultenv: true env: - PGRST_JWT_AUD: 'http://%%localhorst.invalid' + PGRST_JWT_AUD: '[' - name: invalid log-level expect: error use_defaultenv: true diff --git a/test/io/test_cli.py b/test/io/test_cli.py index 0f46e4c0b0..3b48cc0b22 100644 --- a/test/io/test_cli.py +++ b/test/io/test_cli.py @@ -266,15 +266,15 @@ def test_schema_cache_snapshot(baseenv, key, snapshot_yaml): assert formatted == snapshot_yaml -def test_jwt_aud_config_set_to_invalid_uri(defaultenv): - "PostgREST should exit with an error message in output if jwt-aud config is set to an invalid URI" +def test_jwt_aud_config_set_to_invalid_regex(defaultenv): + "PostgREST should exit with an error message in output if jwt-aud config is set to an invalid regular expression" env = { **defaultenv, - "PGRST_JWT_AUD": "foo://%%$$^^.com", + "PGRST_JWT_AUD": "[", } error = cli(["--dump-config"], env=env, expect_error=True) - assert "jwt-aud should be a string or a valid URI" in error + assert "jwt-aud should be a valid regular expression" in error def test_jwt_secret_min_length(defaultenv): diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 3c48a1134d..c05df0e2db 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -34,10 +34,12 @@ import PostgREST.Config (AppConfig (..), JSPathExp (..), LogLevel (..), OpenAPIMode (..), + defaultCfgAud, parseCfgAud, parseSecret) import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) import Protolude hiding (get, toS) import Protolude.Conv (toS) +import Protolude.Partial (fromJust) filterAndMatchCT :: BS.ByteString -> MatchHeader filterAndMatchCT val = MatchHeader $ \headers _ -> @@ -135,7 +137,7 @@ baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in , configDbUri = "postgresql://" , configFilePath = Nothing , configJWKS = rightToMaybe $ parseSecret secret - , configJwtAudience = Nothing + , configJwtAudience = defaultCfgAud , configJwtRoleClaimKey = [JSPKey "role"] , configJwtSecret = Just secret , configJwtSecretIsBase64 = False @@ -218,7 +220,8 @@ testCfgAudienceJWT :: AppConfig testCfgAudienceJWT = baseCfg { configJwtSecret = Just generateSecret - , configJwtAudience = Just "youraudience" + -- parseCfgAud might fail on invalid regex but it is safe here + , configJwtAudience = fromJust $ parseCfgAud "urn..uriaudience|youraudience" , configJWKS = rightToMaybe $ parseSecret generateSecret }