From a9120f01ac13c0ae35172a17326c30f1f2df9220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Tue, 28 Oct 2025 11:35:16 +0100 Subject: [PATCH 1/4] refactor: JWT code reorganization to facilitate more granular caching decisions JWT cache implementation introduced two new modules: Auth.Jwt and Auth.JwtCache. This refactoring reorganizes code in Auth and the above two modules so that reponsibilities and dependencies are more clear: * parseClaims function was moved from Auth.Jwt back to Auth. Thanks to it Auth.Jwt module became independent from AuthResult data structure and role handling. Its only purpose right now is to parse/verify tokens and validate claims * validateClaims function in Auth.Jwt module was split to separate validateAud and validateTimeClaims functions. This change was necessary to allow Auth.JwtCache module to be the only place to decide what validations are cached. * Introduced type level tagging of claim validation results so that it is possible to statically ensure all required validations were performed (see Auth.JwtCache.parseAndValidateClaims signature) * Made Auth.Jwt module independent from Config module: validateAud no longer takes Config as an argument but a (Text -> Bool) function to validate audience values * Auth.JwtCache module was changed so that it is now possilble to cache claims validation results. Tagged claim validation result types are used to ensure all validations are performed regardless of the decision about what should be cached. * JwtCache datatype in Auth.JwtCache module was renamed to CacheState with JwksNotConfigured, NotCaching and Caching constructors. * Creation of a Sieve cache instance was moved to a CacheVariant typeclass function newCache * NeedsReconfiguration typeclass was introduced to handle differences between different CacheVariants in deciding when cache reset is needed (if aud claim validation results are cached we need to reset cache when jwt-aud changes) --- src/PostgREST/Auth.hs | 53 ++++++++++-- src/PostgREST/Auth/Jwt.hs | 145 ++++++++++++++++----------------- src/PostgREST/Auth/JwtCache.hs | 138 ++++++++++++++++++------------- 3 files changed, 197 insertions(+), 139 deletions(-) diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index 7b5b9bd241..134a79d7c7 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RecordWildCards #-} {-| Module : PostgREST.Auth Description : PostgREST authentication functions. @@ -30,13 +32,19 @@ import System.TimeIt (timeItT) import PostgREST.AppState (AppState, getConfig, getJwtCacheState, getTime) -import PostgREST.Auth.Jwt (parseClaims) import PostgREST.Auth.JwtCache (lookupJwtCache) import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Config (AppConfig (..)) -import PostgREST.Error (Error (..)) +import PostgREST.Config (AppConfig (..), FilterExp (..), + JSPath, JSPathExp (..)) +import PostgREST.Error (Error (..), JwtError (..)) -import Protolude +import Control.Monad.Except (liftEither) +import qualified Data.Aeson as JSON +import qualified Data.Aeson.Key as K +import qualified Data.Aeson.KeyMap as KM +import qualified Data.Text as T +import qualified Data.Vector as V +import Protolude -- | Validate authorization header -- Parse and store JWT claims for future use in the request. @@ -46,7 +54,8 @@ middleware appState app req respond = do time <- getTime appState let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) - parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time + parseToken = maybe (pure KM.empty) (lookupJwtCache jwtCacheState time) + parseJwt = runExceptT $ parseToken >=> parseClaims conf $ token jwtCacheState = getJwtCacheState appState -- If ServerTimingEnabled -> calculate JWT validation time @@ -59,6 +68,38 @@ middleware appState app req respond = do app req' respond +parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> JSON.Object -> m AuthResult +parseClaims AppConfig{configJwtRoleClaimKey, configDbAnonRole} mclaims = do + -- role defaults to anon if not specified in jwt + role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $ + unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole + pure AuthResult + { authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role) + , authRole = role + } + where + walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value + walkJSPath x [] = x + walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest + walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest + walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar + walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar + walkJSPath _ _ = Nothing + + findFirstMatch matchWith pattern = foldr checkMatch Nothing + where + checkMatch (JSON.String txt) acc + | pattern `matchWith` txt = Just $ JSON.String txt + | otherwise = acc + checkMatch _ acc = acc + + unquoted :: JSON.Value -> BS.ByteString + unquoted (JSON.String t) = encodeUtf8 t + unquoted v = BS.toStrict $ JSON.encode v + authResultKey :: Vault.Key (Either Error AuthResult) authResultKey = unsafePerformIO Vault.newKey {-# NOINLINE authResultKey #-} diff --git a/src/PostgREST/Auth/Jwt.hs b/src/PostgREST/Auth/Jwt.hs index 734f302ab0..3b83cf6676 100644 --- a/src/PostgREST/Auth/Jwt.hs +++ b/src/PostgREST/Auth/Jwt.hs @@ -4,28 +4,30 @@ Description : PostgREST JWT support functions. This module provides functions to deal with JWT parsing and validation (http://jwt.io). -} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module PostgREST.Auth.Jwt - ( parseAndDecodeClaims - , parseClaims) where - -import qualified Data.Aeson as JSON -import qualified Data.Aeson.Key as K -import qualified Data.Aeson.KeyMap as KM -import qualified Data.ByteString as BS -import qualified Data.ByteString.Internal as BS -import qualified Data.ByteString.Lazy.Char8 as LBS -import qualified Data.Scientific as Sci -import qualified Data.Text as T -import qualified Data.Vector as V -import qualified Jose.Jwk as JWT -import qualified Jose.Jwt as JWT + ( Validation (..) + , Validated (getValidated) + , parseAndDecodeClaims + , validateAud + , validateTimeClaims + , (>>>)) where + +import qualified Data.Aeson as JSON +import qualified Data.Aeson.KeyMap as KM +import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS +import qualified Data.Scientific as Sci +import qualified Jose.Jwk as JWT +import qualified Jose.Jwt as JWT import Control.Monad.Except (liftEither) import Data.Either.Combinators (mapLeft) @@ -33,39 +35,77 @@ import Data.Text () import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import PostgREST.Auth.Types (AuthResult (..)) -import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath, - JSPathExp (..), audMatchesCfg) -import PostgREST.Error (Error (..), - JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed), - JwtDecodeError (..), JwtError (..)) +import PostgREST.Error (Error (..), + JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed), + JwtDecodeError (..), JwtError (..)) import Data.Aeson ((.:?)) import Data.Aeson.Types (parseMaybe) +import Data.Coerce (coerce) import Jose.Jwk (JwkSet) import Protolude hiding (first) +-- A value tagged by a type-level list of validations pefrormed on it +newtype Validated (k :: [v]) a = Validated { getValidated :: a } + +-- Helper to implement type safe validation chaining +type family (++) (lst::[k]) lst' where + '[] ++ lst = lst + (l : ls) ++ lst = l : (ls ++ lst) + +-- Validation chaining operator +(>>>) :: (Monad m, Coercible (m (Validated kc c)) (m (Validated (kb ++ kc) c))) + => (a -> m (Validated kb b)) + -> (b -> m (Validated kc c)) + -> a + -> m (Validated (kb ++ kc) c) +f >>> g = coerce . (f >=> g . coerce) + parseAndDecodeClaims :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JSON.Object -parseAndDecodeClaims jwkSet token = parseToken jwkSet token >>= decodeClaims +parseAndDecodeClaims jwkSet = parseToken jwkSet >=> decodeClaims + +data Validation = Aud | Time + +validateAud :: MonadError Error m => (Text -> Bool) -> JSON.Object -> m (Validated '[Aud] JSON.Object) +validateAud = validate . checkAud + +validateTimeClaims :: MonadError Error m => UTCTime -> JSON.Object -> m (Validated '[Time] JSON.Object) +validateTimeClaims = validate . checkExpNbfIat decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims) decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType -validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m () -validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims) +validate :: MonadError Error m => (t -> Alt Maybe JwtClaimsError) -> t -> m (Validated k t) +validate f claims = fmap Validated $ liftEither $ maybeToLeft claims $ fmap JwtErr . getAlt $ JwtClaimsErr <$> f claims data ValidAud = VAString Text | VAArray [Text] deriving Generic instance JSON.FromJSON ValidAud where parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } -checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError -checkForErrors time audMatches = mconcat +claim :: (JSON.FromJSON a, Applicative f, Monoid (f p)) => KM.Key -> p -> (a -> f p) -> JSON.Object -> f p +claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key) + +checkValue :: (Applicative f, Monoid (f p)) => (t -> Bool) -> p -> t -> f p +checkValue invalid msg val = + if invalid val then + pure msg + else + mempty + +checkAud :: (Applicative f, Monoid (f JwtClaimsError)) => (Text -> Bool) -> JSON.Object -> f JwtClaimsError +checkAud audMatches = claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience + where + validAud = \case + (VAString aud) -> audMatches aud + (VAArray auds) -> null auds || any audMatches auds + +checkExpNbfIat :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> JSON.Object -> m JwtClaimsError +checkExpNbfIat time = mconcat [ claim "exp" ExpClaimNotNumber $ inThePast JWTExpired , claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid , claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture - , claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience ] where allowedSkewSeconds = 30 :: Int64 @@ -78,20 +118,6 @@ checkForErrors time audMatches = mconcat checkTime cond = checkValue (cond. sciToInt) - validAud = \case - (VAString aud) -> audMatches aud - (VAArray auds) -> null auds || any audMatches auds - - checkValue invalid msg val = - if invalid val then - pure msg - else - mempty - - claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key) - --- | Receives the JWT secret and audience (from config) and a JWT and returns a --- JSON object of JWT claims. parseToken :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JWT.JwtContent parseToken _ "" = throwError $ JwtErr $ JwtDecodeErr EmptyAuthHeader parseToken secret tkn = do @@ -116,36 +142,3 @@ parseToken secret tkn = do jwtDecodeError JWT.BadCrypto = JwtDecodeErr BadCrypto -- Control never reaches here, the decode function only returns the above three jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError - -parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult -parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do - validateClaims time (audMatchesCfg cfg) mclaims - -- role defaults to anon if not specified in jwt - role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $ - unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole - pure AuthResult - { authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role) - , authRole = role - } - where - walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value - walkJSPath x [] = x - walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest - walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar - walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar - walkJSPath _ _ = Nothing - - findFirstMatch matchWith pattern = foldr checkMatch Nothing - where - checkMatch (JSON.String txt) acc - | pattern `matchWith` txt = Just $ JSON.String txt - | otherwise = acc - checkMatch _ acc = acc - - unquoted :: JSON.Value -> BS.ByteString - unquoted (JSON.String t) = encodeUtf8 t - unquoted v = LBS.toStrict $ JSON.encode v diff --git a/src/PostgREST/Auth/JwtCache.hs b/src/PostgREST/Auth/JwtCache.hs index 395ecd48d4..86234b1a68 100644 --- a/src/PostgREST/Auth/JwtCache.hs +++ b/src/PostgREST/Auth/JwtCache.hs @@ -1,15 +1,20 @@ {-| -Module : PostgREST.Auth.JwtCache +Module : PostgREST.Auth.Caching Description : PostgREST JWT validation results Cache. This module provides functions to deal with the JWT cache. -} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} module PostgREST.Auth.JwtCache ( init @@ -18,97 +23,116 @@ module PostgREST.Auth.JwtCache , lookupJwtCache ) where -import qualified Data.Aeson as JSON -import qualified Data.Aeson.KeyMap as KM +import qualified Data.Aeson as JSON import PostgREST.Error (Error (..), JwtError (JwtSecretMissing)) import Control.Concurrent.STM (newTVarIO, readTVar, writeTVar) import Control.Concurrent.STM.TVar (TVar) -import Control.Monad.Error.Class (liftEither) import Data.ByteString hiding (all, init) import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import Data.Time.Clock (UTCTime) import Jose.Jwk (JwkSet) -import PostgREST.Auth.Jwt (parseAndDecodeClaims) +import PostgREST.Auth.Jwt (Validated (getValidated), + Validation (..), + parseAndDecodeClaims, + validateAud, + validateTimeClaims, + (>>>)) import PostgREST.Cache.Sieve (alwaysValid) import qualified PostgREST.Cache.Sieve as SC -import PostgREST.Config (AppConfig (..)) +import PostgREST.Config (AppConfig (..), + audMatchesCfg) import PostgREST.Observation (Observation (JwtCacheEviction, JwtCacheLookup), ObservationHandler) import Protolude -data JwtCacheState = JwtCacheState ObservationHandler (IORef JwtCache) +type Cache m v = SC.Cache m ByteString v -class CacheVariant m v where - cached :: SC.Cache m ByteString v -> ByteString -> ExceptT Error IO JSON.Object +type SelectedCacheVariant = Cache (ExceptT Error IO) (Validated '[Aud] JSON.Object) -{-| -Jwt caching can have three different configurations: -* missing JWT Key (no caching and throw error when JWT token present in the request) -* JWT cache turned off -* JWT cache turned on +data CacheState = + JwksNotConfigured | + NotCaching JwkSet (Text -> Bool) | + Caching JwkSet AppConfig (TVar Int) SelectedCacheVariant -All three options are represented by JwtCache data type. +data JwtCacheState = JwtCacheState ObservationHandler (IORef CacheState) -Handling of reconfiguration is centralized in this module. --} -data JwtCache = - JwtNoJwks | - JwtNoCache JwkSet | - forall m v. CacheVariant m v => JwtCache JwkSet (TVar Int) (SC.Cache m ByteString v) +parseAndValidateAud :: CacheState -> ByteString -> ExceptT Error IO (Validated '[Aud] JSON.Object) +parseAndValidateAud (Caching _ config _ c) = lookup config c +parseAndValidateAud (NotCaching key audMatches) = parseAndDecodeClaims key >=> validateAud audMatches +parseAndValidateAud JwksNotConfigured = const $ throwError (JwtErr JwtSecretMissing) + +class NeedsReinitialize v where + needsReinitialize :: AppConfig -> AppConfig -> Bool + +instance NeedsReinitialize JSON.Object where + needsReinitialize _ _ = False -instance CacheVariant IO (Either Error JSON.Object) where - cached c = lift . SC.cached c >=> liftEither +instance NeedsReinitialize (Validated (Aud : rest) JSON.Object) where + needsReinitialize old new = configJwtAudience old /= configJwtAudience new -instance CacheVariant (ExceptT Error IO) JSON.Object where - cached = SC.cached +instance NeedsReinitialize v => NeedsReinitialize (SC.Cache m k v) where + needsReinitialize = needsReinitialize @v -decode :: JwtCache -> ByteString -> ExceptT Error IO JSON.Object -decode JwtNoJwks = const $ throwError (JwtErr JwtSecretMissing) -decode (JwtNoCache key) = parseAndDecodeClaims key -decode (JwtCache _ _ c) = cached c +class CacheVariant c where + newCache :: STM Int -> JwkSet -> ObservationHandler -> AppConfig -> IO c + lookup :: AppConfig -> c -> ByteString -> ExceptT Error IO (Validated '[Aud] JSON.Object) + +-- Cache parsed JWTs with valid signature and valid aud +instance CacheVariant (Cache (ExceptT Error IO) (Validated '[Aud] JSON.Object)) where + newCache maxSize key observationHandler config = SC.cacheIO (SC.CacheConfig maxSize + (parseAndDecodeClaims key >=> validateAud (audMatchesCfg config)) + (lift . observationHandler . JwtCacheLookup) -- lookup metrics + (const . const $ lift $ observationHandler JwtCacheEviction) -- evictions metrics + alwaysValid) -- no invalidation for now + lookup _ = SC.cached -- | Reconfigure JWT caching and update JwtCacheState accordingly update :: JwtCacheState -> AppConfig -> IO () -update (JwtCacheState observationHandler jwtCacheState) config@AppConfig{configJWKS, configJwtCacheMaxEntries} = - let reinitialize = - newJwtCache config observationHandler - >>= writeIORef jwtCacheState - in - readIORef jwtCacheState >>= \case - (JwtCache decodingKey maxSize _) -> - if configJWKS /= Just decodingKey || configJwtCacheMaxEntries <= 0 then - -- reinitialize if key changed or cache disabled +update (JwtCacheState observationHandler ref) config@AppConfig{configJWKS, configJwtCacheMaxEntries} = + readIORef ref >>= \case + (Caching decodingKey oldConfig maxSize cache) -> + if configJWKS /= Just decodingKey || + configJwtCacheMaxEntries <= 0 || + needsReinitialize @SelectedCacheVariant oldConfig config + then + -- reinitialize if key changed or cache disabled or the cache requires reinitialization reinitialize - else - -- max size changed - set it and let the cache shrink itself if necessary + else do + -- max size changed - set it and let the cache resize itself if necessary atomically $ writeTVar maxSize configJwtCacheMaxEntries - + -- save new config for future updates + writeIORef ref $ Caching decodingKey config maxSize cache _ -> reinitialize + where + reinitialize = newJwtCacheState config observationHandler >>= writeIORef ref init :: AppConfig -> ObservationHandler -> IO JwtCacheState -init config = fmap (<$>) JwtCacheState <*> (newJwtCache config >=> newIORef) +init config = fmap (<$>) JwtCacheState <*> (newJwtCacheState config >=> newIORef) -- | Initialize JwtCacheState -newJwtCache :: AppConfig -> ObservationHandler -> IO JwtCache -newJwtCache AppConfig{configJWKS, configJwtCacheMaxEntries} observationHandler = do - maybe (pure JwtNoJwks) initCache configJWKS +newJwtCacheState :: AppConfig -> ObservationHandler -> IO CacheState +newJwtCacheState config@AppConfig{configJWKS, configJwtCacheMaxEntries} observationHandler = + maybe (pure JwksNotConfigured) initCache configJWKS where - initCache key = if configJwtCacheMaxEntries <= 0 then pure (JwtNoCache key) else createCache key configJwtCacheMaxEntries + initCache key = + if configJwtCacheMaxEntries <= 0 then + pure $ NotCaching key (audMatchesCfg config) + else + createCache key configJwtCacheMaxEntries createCache key maxSize = do - maxSizeTVar <- newTVarIO maxSize - JwtCache key maxSizeTVar <$> - notCachingErrors (readTVar maxSizeTVar) key + maxSizeTVar <- newTVarIO maxSize + Caching key config maxSizeTVar <$> + newCache + (readTVar maxSizeTVar) key observationHandler config - notCachingErrors :: STM Int -> JwkSet -> IO (SC.Cache (ExceptT Error IO) ByteString JSON.Object) - notCachingErrors maxSize key = SC.cacheIO (SC.CacheConfig maxSize - (parseAndDecodeClaims key) - (lift . observationHandler . JwtCacheLookup) -- lookup metrics - (const . const $ lift $ observationHandler JwtCacheEviction) -- evictions metrics - alwaysValid) -- no invalidation for now +parseAndValidateClaims :: UTCTime -> ByteString -> CacheState -> ExceptT Error IO (Validated [Aud, Time] JSON.Object) +parseAndValidateClaims time k c = (parseAndValidateAud c >>> validateTimeClaims time) k -lookupJwtCache :: JwtCacheState -> Maybe ByteString -> ExceptT Error IO JSON.Object -lookupJwtCache (JwtCacheState _ cacheState) k = liftIO (readIORef cacheState) >>= flip (maybe (pure KM.empty)) k . decode +lookupJwtCache :: JwtCacheState -> UTCTime -> ByteString -> ExceptT Error IO JSON.Object +lookupJwtCache (JwtCacheState _ cacheState) time k = + liftIO (readIORef cacheState) >>= fmap getValidated . parseAndValidateClaims time k From 1c9d633e941406e4c4b60fdb438bb568f0899038 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Wed, 29 Oct 2025 10:27:45 +0100 Subject: [PATCH 2/4] change: Make jwt-aud config value a regular expression This change adds flexibility to aud claim validation. jwt-aud configuration property can now be specified as a regular expression. For example, it is now possible to * configure multiple acceptable aud values with '|' regex operator, eg: 'audience1|audience2|audience3' * accept any audience from a particular domain, eg: 'https://[a-z0-9]*\.example\.com' --- CHANGELOG.md | 1 + docs/references/auth.rst | 7 ++++- docs/references/configuration.rst | 6 ++-- src/PostgREST/Config.hs | 50 ++++++++++++++++++------------- test/io/fixtures.yaml | 2 +- test/io/test_cli.py | 8 ++--- test/spec/SpecHelper.hs | 7 +++-- 7 files changed, 50 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 685a564c28..e7b377417a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ All notable changes to this project will be documented in this file. From versio - From now on PostgREST will follow a `MAJOR.PATCH` two-part versioning. Only even-numbered MAJOR versions will be released, reserving odd-numbered MAJOR versions for development. - Replaced `jwt-cache-max-lifetime` config with `jwt-cache-max-entries` by @mkleczek in #4084 - `log-query` config now takes a boolean instead of a string value by @steve-chavez in #3934 +- `jwt-aud` config now takes a regular expression to match against `aud` claim #2099 ## [13.0.8] - 2025-10-24 diff --git a/docs/references/auth.rst b/docs/references/auth.rst index b92bd8c210..81625e214e 100644 --- a/docs/references/auth.rst +++ b/docs/references/auth.rst @@ -193,13 +193,18 @@ PostgREST has built-in validation of the `JWT audience claim ` error. + If the ``aud`` key **is not present** or if its value is ``null`` or ``[]``, PostgREST will interpret this token as allowed for all audiences and will complete the request. +Examples: +- To make PostgREST accept ``aud`` claim value from a set ``audience1``, ``audience2``, ``otheraudience``, :ref:`jwt-aud` claim should be set to ``audience1|audience2|otheraudience``. +- To make PostgREST accept ``aud`` claim value matching any ``https`` URI pointing to a host in ``example.com`` domain, :ref:`jwt-aud` claim should be set to ``https://[a-zA-Z0-9_]*\.example\.com``. +- To make PostgREST accept any ``aud`` claim value , :ref:`jwt-aud` claim should be set to ``.*`` (which is the default). + .. _jwt_caching: JWT Cache diff --git a/docs/references/configuration.rst b/docs/references/configuration.rst index c3b563f1f2..11892904c3 100644 --- a/docs/references/configuration.rst +++ b/docs/references/configuration.rst @@ -596,14 +596,14 @@ jwt-aud ------- =============== ================================= - **Type** String - **Default** `n/a` + **Type** String (must be a valid regular expression) + **Default** `.*` **Reloadable** Y **Environment** PGRST_JWT_AUD **In-Database** pgrst.jwt_aud =============== ================================= - Specifies an audience for the JWT ``aud`` claim. See :ref:`jwt_aud`. + Specifies a regular expression to match against the JWT ``aud`` claim. See :ref:`jwt_aud`. .. _jwt-role-claim-key: diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index 2bfcb9e93b..38cfe66847 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -19,6 +19,7 @@ module PostgREST.Config , LogLevel(..) , OpenAPIMode(..) , Proxy(..) + , CfgAud , toText , isMalformedProxyUri , readAppConfig @@ -29,6 +30,8 @@ module PostgREST.Config , addTargetSessionAttrs , exampleConfigFile , audMatchesCfg + , defaultCfgAud + , parseCfgAud ) where import qualified Data.Aeson as JSON @@ -50,7 +53,7 @@ import Data.List.NonEmpty (fromList, toList) import Data.Maybe (fromJust) import Data.Scientific (floatingOrInteger) import Jose.Jwk (Jwk, JwkSet) -import Network.URI (escapeURIString, isURI, +import Network.URI (escapeURIString, isUnescapedInURIComponent) import Numeric (readOct, showOct) import System.Environment (getEnvironment) @@ -66,10 +69,31 @@ import PostgREST.Config.Proxy (Proxy (..), import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier, dumpQi, toQi) -import Protolude hiding (Proxy, toList) +import Protolude hiding (Proxy, toList) +import qualified Text.Regex.TDFA as R + +data ParsedValue a b = ParsedValue { + sourceValue :: a, + parsedValue :: b +} +instance Eq a => Eq (ParsedValue a b) where + x == y = sourceValue x == sourceValue y + +newtype CfgAud = CfgAud { unCfgAud :: ParsedValue (Maybe Text) R.Regex } deriving Eq + +parseCfgAud :: MonadFail m => Text -> m CfgAud +parseCfgAud = fmap CfgAud . (fmap . ParsedValue . Just <*> parseRegex) + where + parseRegex = maybe (fail "jwt-aud should be a valid regular expression") pure . R.makeRegexM . bounded + -- need start and end of text bounds so that + -- regex does not match parts of text + bounded = ("\\`(" <>) . (<> "\\')") + +defaultCfgAud :: CfgAud +defaultCfgAud = CfgAud $ ParsedValue Nothing $ R.makeRegex (".*"::Text) audMatchesCfg :: AppConfig -> Text -> Bool -audMatchesCfg = maybe (const True) (==) . configJwtAudience +audMatchesCfg = R.matchTest . parsedValue . unCfgAud . configJwtAudience data AppConfig = AppConfig { configAppSettings :: [(Text, Text)] @@ -97,7 +121,7 @@ data AppConfig = AppConfig , configDbUri :: Text , configFilePath :: Maybe FilePath , configJWKS :: Maybe JwkSet - , configJwtAudience :: Maybe Text + , configJwtAudience :: CfgAud , configJwtRoleClaimKey :: JSPath , configJwtSecret :: Maybe BS.ByteString , configJwtSecretIsBase64 :: Bool @@ -171,7 +195,7 @@ toText conf = ,("db-pre-config", q . maybe mempty dumpQi . configDbPreConfig) ,("db-tx-end", q . showTxEnd) ,("db-uri", q . configDbUri) - ,("jwt-aud", q . fromMaybe mempty . configJwtAudience) + ,("jwt-aud", q . fold . sourceValue . unCfgAud . configJwtAudience) ,("jwt-role-claim-key", q . T.intercalate mempty . fmap dumpJSPath . configJwtRoleClaimKey) ,("jwt-secret", q . T.decodeUtf8 . showJwtSecret) ,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64) @@ -279,7 +303,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = <*> (fromMaybe "postgresql://" <$> optString "db-uri") <*> pure optPath <*> pure Nothing - <*> optStringOrURI "jwt-aud" + <*> (optStringEmptyable "jwt-aud" >>= maybe (pure defaultCfgAud) parseCfgAud) <*> parseRoleClaimKey "jwt-role-claim-key" "role-claim-key" <*> (fmap encodeUtf8 <$> optString "jwt-secret") <*> (fromMaybe False <$> optWithAlias @@ -399,20 +423,6 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = optStringEmptyable :: C.Key -> C.Parser C.Config (Maybe Text) optStringEmptyable k = overrideFromDbOrEnvironment C.optional k coerceText - optStringOrURI :: C.Key -> C.Parser C.Config (Maybe Text) - optStringOrURI k = do - stringOrURI <- mfilter (/= "") <$> overrideFromDbOrEnvironment C.optional k coerceText - -- If the string contains ':' then it should - -- be a valid URI according to RFC 3986 - case stringOrURI of - Just s -> if T.isInfixOf ":" s then validateURI s else return (Just s) - Nothing -> return Nothing - where - validateURI :: Text -> C.Parser C.Config (Maybe Text) - validateURI s = if isURI (T.unpack s) - then return $ Just s - else fail "jwt-aud should be a string or a valid URI" - optInt :: (Read i, Integral i) => C.Key -> C.Parser C.Config (Maybe i) optInt k = join <$> overrideFromDbOrEnvironment C.optional k coerceInt diff --git a/test/io/fixtures.yaml b/test/io/fixtures.yaml index ee00182cc0..68892f5dfb 100644 --- a/test/io/fixtures.yaml +++ b/test/io/fixtures.yaml @@ -45,7 +45,7 @@ cli: expect: error use_defaultenv: true env: - PGRST_JWT_AUD: 'http://%%localhorst.invalid' + PGRST_JWT_AUD: '[' - name: invalid log-level expect: error use_defaultenv: true diff --git a/test/io/test_cli.py b/test/io/test_cli.py index 0f46e4c0b0..3b48cc0b22 100644 --- a/test/io/test_cli.py +++ b/test/io/test_cli.py @@ -266,15 +266,15 @@ def test_schema_cache_snapshot(baseenv, key, snapshot_yaml): assert formatted == snapshot_yaml -def test_jwt_aud_config_set_to_invalid_uri(defaultenv): - "PostgREST should exit with an error message in output if jwt-aud config is set to an invalid URI" +def test_jwt_aud_config_set_to_invalid_regex(defaultenv): + "PostgREST should exit with an error message in output if jwt-aud config is set to an invalid regular expression" env = { **defaultenv, - "PGRST_JWT_AUD": "foo://%%$$^^.com", + "PGRST_JWT_AUD": "[", } error = cli(["--dump-config"], env=env, expect_error=True) - assert "jwt-aud should be a string or a valid URI" in error + assert "jwt-aud should be a valid regular expression" in error def test_jwt_secret_min_length(defaultenv): diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 3c48a1134d..c05df0e2db 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -34,10 +34,12 @@ import PostgREST.Config (AppConfig (..), JSPathExp (..), LogLevel (..), OpenAPIMode (..), + defaultCfgAud, parseCfgAud, parseSecret) import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) import Protolude hiding (get, toS) import Protolude.Conv (toS) +import Protolude.Partial (fromJust) filterAndMatchCT :: BS.ByteString -> MatchHeader filterAndMatchCT val = MatchHeader $ \headers _ -> @@ -135,7 +137,7 @@ baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in , configDbUri = "postgresql://" , configFilePath = Nothing , configJWKS = rightToMaybe $ parseSecret secret - , configJwtAudience = Nothing + , configJwtAudience = defaultCfgAud , configJwtRoleClaimKey = [JSPKey "role"] , configJwtSecret = Just secret , configJwtSecretIsBase64 = False @@ -218,7 +220,8 @@ testCfgAudienceJWT :: AppConfig testCfgAudienceJWT = baseCfg { configJwtSecret = Just generateSecret - , configJwtAudience = Just "youraudience" + -- parseCfgAud might fail on invalid regex but it is safe here + , configJwtAudience = fromJust $ parseCfgAud "urn..uriaudience|youraudience" , configJWKS = rightToMaybe $ parseSecret generateSecret } From d71cb81b2dc833e1418532adeaa002d8e4eb3af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Wed, 29 Oct 2025 16:00:26 +0100 Subject: [PATCH 3/4] test: Add aud claim validation to load tests --- nix/tools/generate_targets.py | 1 + nix/tools/loadtest.nix | 1 + 2 files changed, 2 insertions(+) diff --git a/nix/tools/generate_targets.py b/nix/tools/generate_targets.py index 762a2c21ab..e7b637e964 100644 --- a/nix/tools/generate_targets.py +++ b/nix/tools/generate_targets.py @@ -33,6 +33,7 @@ def generate_jwt(now: int, exp_inc: Optional[int], is_hs: bool) -> str: payload = { "sub": f"user_{random.getrandbits(32)}", "iat": now, + "aud": "veryveryveryveryverylonglonglonglonglongaudience", "role": "postgrest_test_author", } diff --git a/nix/tools/loadtest.nix b/nix/tools/loadtest.nix index 7ff786efba..6181d1403f 100644 --- a/nix/tools/loadtest.nix +++ b/nix/tools/loadtest.nix @@ -63,6 +63,7 @@ let # load test works across branches # TODO clean once PGRST_JWT_CACHE_MAX_ENTRIES merged and released export PGRST_JWT_CACHE_MAX_LIFETIME="86400" + export PGRST_JWT_AUD="audience|([a-z]er[a-z])*(long)*aud..nce" mkdir -p "$(dirname "$_arg_output")" abs_output="$(realpath "$_arg_output")" From e15e2648b2478ac4bce2f018c6ebceab145b9f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Tue, 28 Oct 2025 23:16:46 +0100 Subject: [PATCH 4/4] change: Set jwt-aud default value to \`\' (accepting only empty string) Fixes #4134 (JWT with aud claim should be rejected if jwt-aud is not set) Updated default jwt-aud value in Config module. Updated spec tests. --- docs/references/auth.rst | 2 +- src/PostgREST/Config.hs | 2 +- test/spec/Feature/Auth/AudienceJwtSecretSpec.hs | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/references/auth.rst b/docs/references/auth.rst index 81625e214e..2ebe32902b 100644 --- a/docs/references/auth.rst +++ b/docs/references/auth.rst @@ -203,7 +203,7 @@ It works this way: Examples: - To make PostgREST accept ``aud`` claim value from a set ``audience1``, ``audience2``, ``otheraudience``, :ref:`jwt-aud` claim should be set to ``audience1|audience2|otheraudience``. - To make PostgREST accept ``aud`` claim value matching any ``https`` URI pointing to a host in ``example.com`` domain, :ref:`jwt-aud` claim should be set to ``https://[a-zA-Z0-9_]*\.example\.com``. -- To make PostgREST accept any ``aud`` claim value , :ref:`jwt-aud` claim should be set to ``.*`` (which is the default). +- To make PostgREST accept any ``aud`` claim value , :ref:`jwt-aud` claim should be set to ``.*``. .. _jwt_caching: diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index 38cfe66847..e4a6544e81 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -90,7 +90,7 @@ parseCfgAud = fmap CfgAud . (fmap . ParsedValue . Just <*> parseRegex) bounded = ("\\`(" <>) . (<> "\\')") defaultCfgAud :: CfgAud -defaultCfgAud = CfgAud $ ParsedValue Nothing $ R.makeRegex (".*"::Text) +defaultCfgAud = CfgAud $ ParsedValue Nothing $ R.makeRegex ("\\`\\'"::Text) audMatchesCfg :: AppConfig -> Text -> Bool audMatchesCfg = R.matchTest . parsedValue . unCfgAud . configJwtAudience diff --git a/test/spec/Feature/Auth/AudienceJwtSecretSpec.hs b/test/spec/Feature/Auth/AudienceJwtSecretSpec.hs index d4ad2968ae..dd17e1b597 100644 --- a/test/spec/Feature/Auth/AudienceJwtSecretSpec.hs +++ b/test/spec/Feature/Auth/AudienceJwtSecretSpec.hs @@ -151,7 +151,7 @@ disabledSpec :: SpecWith ((), Application) disabledSpec = describe "test handling of aud claims in JWT when the jwt-aud config is not set" $ do context "when the audience claim is a string" $ do - it "ignores the audience claim and suceeds" $ do + it "fails when it is not empty" $ do let jwtPayload = [json|{ "exp": 9999999999, @@ -161,7 +161,7 @@ disabledSpec = describe "test handling of aud claims in JWT when the jwt-aud con }|] auth = authHeaderJWT $ generateJWT jwtPayload request methodGet "/authors_only" [auth] "" - `shouldRespondWith` 200 + `shouldRespondWith` 401 it "ignores the audience claim and suceeds when it's empty" $ do let jwtPayload = @@ -176,7 +176,7 @@ disabledSpec = describe "test handling of aud claims in JWT when the jwt-aud con `shouldRespondWith` 200 context "when the audience is an array of strings" $ do - it "ignores the audience claim and suceeds when it has 1 element" $ do + it "fails it has 1 element" $ do let jwtPayload = [json| { "exp": 9999999999, @@ -186,9 +186,9 @@ disabledSpec = describe "test handling of aud claims in JWT when the jwt-aud con }|] auth = authHeaderJWT $ generateJWT jwtPayload request methodGet "/authors_only" [auth] "" - `shouldRespondWith` 200 + `shouldRespondWith` 401 - it "ignores the audience claim and suceeds when it has more than 1 element" $ do + it "fails when it has more than 1 element" $ do let jwtPayload = [json| { "exp": 9999999999, @@ -198,7 +198,7 @@ disabledSpec = describe "test handling of aud claims in JWT when the jwt-aud con }|] auth = authHeaderJWT $ generateJWT jwtPayload request methodGet "/authors_only" [auth] "" - `shouldRespondWith` 200 + `shouldRespondWith` 401 it "ignores the audience claim and suceeds when it's empty" $ do let jwtPayload = [json|