Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ All notable changes to this project will be documented in this file. From versio
- From now on PostgREST will follow a `MAJOR.PATCH` two-part versioning. Only even-numbered MAJOR versions will be released, reserving odd-numbered MAJOR versions for development.
- Replaced `jwt-cache-max-lifetime` config with `jwt-cache-max-entries` by @mkleczek in #4084
- `log-query` config now takes a boolean instead of a string value by @steve-chavez in #3934
- `jwt-aud` config now takes a regular expression to match against `aud` claim #2099

## [13.0.8] - 2025-10-24

Expand Down
7 changes: 6 additions & 1 deletion docs/references/auth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,18 @@ PostgREST has built-in validation of the `JWT audience claim <https://datatracke
It works this way:

- If :ref:`jwt-aud` is not set (the default), PostgREST identifies with all audiences and allows the JWT for any ``aud`` claim.
- If :ref:`jwt-aud` is set to a specific audience, PostgREST will check if this audience is present in the ``aud`` claim:
- If :ref:`jwt-aud` is set, PostgREST will treat it as a regular expression and check if it matches the ``aud`` claim:

+ If the ``aud`` value is a JSON string, it will match it to the :ref:`jwt-aud`.
+ If the ``aud`` value is a JSON array of strings, it will search every element for a match.
+ If the match fails or if the ``aud`` value is not a string or array of strings, then the token will be rejected with a :ref:`401 Unauthorized <pgrst303>` error.
+ If the ``aud`` key **is not present** or if its value is ``null`` or ``[]``, PostgREST will interpret this token as allowed for all audiences and will complete the request.

Examples:
- To make PostgREST accept ``aud`` claim value from a set ``audience1``, ``audience2``, ``otheraudience``, :ref:`jwt-aud` claim should be set to ``audience1|audience2|otheraudience``.
- To make PostgREST accept ``aud`` claim value matching any ``https`` URI pointing to a host in ``example.com`` domain, :ref:`jwt-aud` claim should be set to ``https://[a-zA-Z0-9_]*\.example\.com``.
- To make PostgREST accept any ``aud`` claim value , :ref:`jwt-aud` claim should be set to ``.*``.

.. _jwt_caching:

JWT Cache
Expand Down
6 changes: 3 additions & 3 deletions docs/references/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -596,14 +596,14 @@ jwt-aud
-------

=============== =================================
**Type** String
**Default** `n/a`
**Type** String (must be a valid regular expression)
**Default** `.*`
**Reloadable** Y
**Environment** PGRST_JWT_AUD
**In-Database** pgrst.jwt_aud
=============== =================================

Specifies an audience for the JWT ``aud`` claim. See :ref:`jwt_aud`.
Specifies a regular expression to match against the JWT ``aud`` claim. See :ref:`jwt_aud`.

.. _jwt-role-claim-key:

Expand Down
1 change: 1 addition & 0 deletions nix/tools/generate_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def generate_jwt(now: int, exp_inc: Optional[int], is_hs: bool) -> str:
payload = {
"sub": f"user_{random.getrandbits(32)}",
"iat": now,
"aud": "veryveryveryveryverylonglonglonglonglongaudience",
"role": "postgrest_test_author",
}

Expand Down
1 change: 1 addition & 0 deletions nix/tools/loadtest.nix
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ let
# load test works across branches
# TODO clean once PGRST_JWT_CACHE_MAX_ENTRIES merged and released
export PGRST_JWT_CACHE_MAX_LIFETIME="86400"
export PGRST_JWT_AUD="audience|([a-z]er[a-z])*(long)*aud..nce"
mkdir -p "$(dirname "$_arg_output")"
abs_output="$(realpath "$_arg_output")"
Expand Down
53 changes: 47 additions & 6 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-|
Module : PostgREST.Auth
Description : PostgREST authentication functions.
Expand Down Expand Up @@ -30,13 +32,19 @@ import System.TimeIt (timeItT)

import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
getTime)
import PostgREST.Auth.Jwt (parseClaims)
import PostgREST.Auth.JwtCache (lookupJwtCache)
import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Config (AppConfig (..))
import PostgREST.Error (Error (..))
import PostgREST.Config (AppConfig (..), FilterExp (..),
JSPath, JSPathExp (..))
import PostgREST.Error (Error (..), JwtError (..))

import Protolude
import Control.Monad.Except (liftEither)
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import qualified Data.Text as T
import qualified Data.Vector as V
import Protolude

-- | Validate authorization header
-- Parse and store JWT claims for future use in the request.
Expand All @@ -46,7 +54,8 @@ middleware appState app req respond = do
time <- getTime appState

let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time
parseToken = maybe (pure KM.empty) (lookupJwtCache jwtCacheState time)
parseJwt = runExceptT $ parseToken >=> parseClaims conf $ token
jwtCacheState = getJwtCacheState appState

-- If ServerTimingEnabled -> calculate JWT validation time
Expand All @@ -59,6 +68,38 @@ middleware appState app req respond = do

app req' respond

parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> JSON.Object -> m AuthResult
parseClaims AppConfig{configJwtRoleClaimKey, configDbAnonRole} mclaims = do
-- role defaults to anon if not specified in jwt
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
pure AuthResult
{ authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role)
, authRole = role
}
where
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
walkJSPath _ _ = Nothing

findFirstMatch matchWith pattern = foldr checkMatch Nothing
where
checkMatch (JSON.String txt) acc
| pattern `matchWith` txt = Just $ JSON.String txt
| otherwise = acc
checkMatch _ acc = acc

unquoted :: JSON.Value -> BS.ByteString
unquoted (JSON.String t) = encodeUtf8 t
unquoted v = BS.toStrict $ JSON.encode v

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
{-# NOINLINE authResultKey #-}
Expand Down
145 changes: 69 additions & 76 deletions src/PostgREST/Auth/Jwt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,68 +4,108 @@ Description : PostgREST JWT support functions.

This module provides functions to deal with JWT parsing and validation (http://jwt.io).
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module PostgREST.Auth.Jwt
( parseAndDecodeClaims
, parseClaims) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.Key as K
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Scientific as Sci
import qualified Data.Text as T
import qualified Data.Vector as V
import qualified Jose.Jwk as JWT
import qualified Jose.Jwt as JWT
( Validation (..)
, Validated (getValidated)
, parseAndDecodeClaims
, validateAud
, validateTimeClaims
, (>>>)) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.Scientific as Sci
import qualified Jose.Jwk as JWT
import qualified Jose.Jwt as JWT

import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.Text ()
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)

import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath,
JSPathExp (..), audMatchesCfg)
import PostgREST.Error (Error (..),
JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed),
JwtDecodeError (..), JwtError (..))
import PostgREST.Error (Error (..),
JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed),
JwtDecodeError (..), JwtError (..))

import Data.Aeson ((.:?))
import Data.Aeson.Types (parseMaybe)
import Data.Coerce (coerce)
import Jose.Jwk (JwkSet)
import Protolude hiding (first)

-- A value tagged by a type-level list of validations pefrormed on it
newtype Validated (k :: [v]) a = Validated { getValidated :: a }

-- Helper to implement type safe validation chaining
type family (++) (lst::[k]) lst' where
'[] ++ lst = lst
(l : ls) ++ lst = l : (ls ++ lst)

-- Validation chaining operator
(>>>) :: (Monad m, Coercible (m (Validated kc c)) (m (Validated (kb ++ kc) c)))
=> (a -> m (Validated kb b))
-> (b -> m (Validated kc c))
-> a
-> m (Validated (kb ++ kc) c)
f >>> g = coerce . (f >=> g . coerce)

parseAndDecodeClaims :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JSON.Object
parseAndDecodeClaims jwkSet token = parseToken jwkSet token >>= decodeClaims
parseAndDecodeClaims jwkSet = parseToken jwkSet >=> decodeClaims

data Validation = Aud | Time

validateAud :: MonadError Error m => (Text -> Bool) -> JSON.Object -> m (Validated '[Aud] JSON.Object)
validateAud = validate . checkAud

validateTimeClaims :: MonadError Error m => UTCTime -> JSON.Object -> m (Validated '[Time] JSON.Object)
validateTimeClaims = validate . checkExpNbfIat

decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object
decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims)
decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType

validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m ()
validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims)
validate :: MonadError Error m => (t -> Alt Maybe JwtClaimsError) -> t -> m (Validated k t)
validate f claims = fmap Validated $ liftEither $ maybeToLeft claims $ fmap JwtErr . getAlt $ JwtClaimsErr <$> f claims

data ValidAud = VAString Text | VAArray [Text] deriving Generic
instance JSON.FromJSON ValidAud where
parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue }

checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError
checkForErrors time audMatches = mconcat
claim :: (JSON.FromJSON a, Applicative f, Monoid (f p)) => KM.Key -> p -> (a -> f p) -> JSON.Object -> f p
claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)

checkValue :: (Applicative f, Monoid (f p)) => (t -> Bool) -> p -> t -> f p
checkValue invalid msg val =
if invalid val then
pure msg
else
mempty

checkAud :: (Applicative f, Monoid (f JwtClaimsError)) => (Text -> Bool) -> JSON.Object -> f JwtClaimsError
checkAud audMatches = claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
where
validAud = \case
(VAString aud) -> audMatches aud
(VAArray auds) -> null auds || any audMatches auds

checkExpNbfIat :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> JSON.Object -> m JwtClaimsError
checkExpNbfIat time = mconcat
[
claim "exp" ExpClaimNotNumber $ inThePast JWTExpired
, claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid
, claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture
, claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
]
where
allowedSkewSeconds = 30 :: Int64
Expand All @@ -78,20 +118,6 @@ checkForErrors time audMatches = mconcat

checkTime cond = checkValue (cond. sciToInt)

validAud = \case
(VAString aud) -> audMatches aud
(VAArray auds) -> null auds || any audMatches auds

checkValue invalid msg val =
if invalid val then
pure msg
else
mempty

claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)

-- | Receives the JWT secret and audience (from config) and a JWT and returns a
-- JSON object of JWT claims.
parseToken :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JWT.JwtContent
parseToken _ "" = throwError $ JwtErr $ JwtDecodeErr EmptyAuthHeader
parseToken secret tkn = do
Expand All @@ -116,36 +142,3 @@ parseToken secret tkn = do
jwtDecodeError JWT.BadCrypto = JwtDecodeErr BadCrypto
-- Control never reaches here, the decode function only returns the above three
jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError

parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult
parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do
validateClaims time (audMatchesCfg cfg) mclaims
-- role defaults to anon if not specified in jwt
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
pure AuthResult
{ authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role)
, authRole = role
}
where
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath x [] = x
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
walkJSPath _ _ = Nothing

findFirstMatch matchWith pattern = foldr checkMatch Nothing
where
checkMatch (JSON.String txt) acc
| pattern `matchWith` txt = Just $ JSON.String txt
| otherwise = acc
checkMatch _ acc = acc

unquoted :: JSON.Value -> BS.ByteString
unquoted (JSON.String t) = encodeUtf8 t
unquoted v = LBS.toStrict $ JSON.encode v
Loading