From 04d3473eff36ad155155db8038b58deaa57945f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Tue, 21 Oct 2025 06:52:40 +0200 Subject: [PATCH 1/2] refactor: Encapsulate aud config This change is an initial step to change JWT aud configuration to regular expression. Exporting function audMatchesCfg :: AppConfig -> Text -> Bool from Config module allows changing the way how JWT aud is configured to be isolated and not affect code in Auth.JWT --- src/PostgREST/Auth/Jwt.hs | 26 ++++++++++++-------------- src/PostgREST/Config.hs | 3 +++ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/PostgREST/Auth/Jwt.hs b/src/PostgREST/Auth/Jwt.hs index 3223ff7447..db88f1be9f 100644 --- a/src/PostgREST/Auth/Jwt.hs +++ b/src/PostgREST/Auth/Jwt.hs @@ -35,7 +35,7 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) import PostgREST.Auth.Types (AuthResult (..)) import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath, - JSPathExp (..)) + JSPathExp (..), audMatchesCfg) import PostgREST.Error (Error (..), JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed), JwtDecodeError (..), JwtError (..)) @@ -52,21 +52,21 @@ decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims) decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType -validateClaims :: MonadError Error m => UTCTime -> Maybe Text -> JSON.Object -> m () -validateClaims time getConfigAud claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time getConfigAud claims) +validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m () +validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims) data ValidAud = VANull | VAString Text | VAArray [Text] deriving Generic instance JSON.FromJSON ValidAud where parseJSON JSON.Null = pure VANull parseJSON o = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } o -checkForErrors :: (Monad m, forall a. Monoid (m a)) => UTCTime -> Maybe Text -> JSON.Object -> m JwtClaimsError -checkForErrors time cfgAud = mconcat +checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError +checkForErrors time audMatches = mconcat [ claim "exp" ExpClaimNotNumber $ inThePast JWTExpired , claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid , claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture - , claim "aud" AudClaimNotStringOrArray checkAud + , claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience ] where allowedSkewSeconds = 30 :: Int64 @@ -79,12 +79,10 @@ checkForErrors time cfgAud = mconcat checkTime cond = checkValue (cond. sciToInt) - checkAud = \case - (VAString aud) -> liftMaybe cfgAud >>= checkValue (aud /=) JWTNotInAudience - (VAArray auds) | (not . null) auds -> liftMaybe cfgAud >>= checkValue (not . (`elem` auds)) JWTNotInAudience - _ -> mempty - - liftMaybe = maybe mempty pure + validAud = \case + (VAString aud) -> audMatches aud + (VAArray auds) -> null auds || any audMatches auds + _ -> True checkValue invalid msg val = if invalid val then @@ -122,8 +120,8 @@ parseToken secret tkn = do jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult -parseClaims AppConfig{configJwtAudience, configJwtRoleClaimKey, configDbAnonRole} time mclaims = do - validateClaims time configJwtAudience mclaims +parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do + validateClaims time (audMatchesCfg cfg) mclaims -- role defaults to anon if not specified in jwt role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $ unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index b760ccced3..2bfcb9e93b 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -28,6 +28,7 @@ module PostgREST.Config , addFallbackAppName , addTargetSessionAttrs , exampleConfigFile + , audMatchesCfg ) where import qualified Data.Aeson as JSON @@ -67,6 +68,8 @@ import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier, dumpQi, import Protolude hiding (Proxy, toList) +audMatchesCfg :: AppConfig -> Text -> Bool +audMatchesCfg = maybe (const True) (==) . configJwtAudience data AppConfig = AppConfig { configAppSettings :: [(Text, Text)] From ad43e0458dfef31f682bdff7aa1c98ed7bc5b76c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C5=82eczek?= Date: Tue, 21 Oct 2025 12:06:03 +0200 Subject: [PATCH 2/2] refactor: Remove redundant VANull constructor in Auth.JWT module --- src/PostgREST/Auth/Jwt.hs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/PostgREST/Auth/Jwt.hs b/src/PostgREST/Auth/Jwt.hs index db88f1be9f..734f302ab0 100644 --- a/src/PostgREST/Auth/Jwt.hs +++ b/src/PostgREST/Auth/Jwt.hs @@ -55,10 +55,9 @@ decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m () validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims) -data ValidAud = VANull | VAString Text | VAArray [Text] deriving Generic +data ValidAud = VAString Text | VAArray [Text] deriving Generic instance JSON.FromJSON ValidAud where - parseJSON JSON.Null = pure VANull - parseJSON o = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } o + parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError checkForErrors time audMatches = mconcat @@ -82,7 +81,6 @@ checkForErrors time audMatches = mconcat validAud = \case (VAString aud) -> audMatches aud (VAArray auds) -> null auds || any audMatches auds - _ -> True checkValue invalid msg val = if invalid val then