Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions src/PostgREST/Auth/Jwt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)

import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath,
JSPathExp (..))
JSPathExp (..), audMatchesCfg)
import PostgREST.Error (Error (..),
JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed),
JwtDecodeError (..), JwtError (..))
Expand All @@ -52,21 +52,20 @@ decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object
decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims)
decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType

validateClaims :: MonadError Error m => UTCTime -> Maybe Text -> JSON.Object -> m ()
validateClaims time getConfigAud claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time getConfigAud claims)
validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m ()
validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims)

data ValidAud = VANull | VAString Text | VAArray [Text] deriving Generic
data ValidAud = VAString Text | VAArray [Text] deriving Generic
instance JSON.FromJSON ValidAud where
parseJSON JSON.Null = pure VANull
parseJSON o = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue } o
parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue }

checkForErrors :: (Monad m, forall a. Monoid (m a)) => UTCTime -> Maybe Text -> JSON.Object -> m JwtClaimsError
checkForErrors time cfgAud = mconcat
checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError
checkForErrors time audMatches = mconcat
[
claim "exp" ExpClaimNotNumber $ inThePast JWTExpired
, claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid
, claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture
, claim "aud" AudClaimNotStringOrArray checkAud
, claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
]
where
allowedSkewSeconds = 30 :: Int64
Expand All @@ -79,12 +78,9 @@ checkForErrors time cfgAud = mconcat

checkTime cond = checkValue (cond. sciToInt)

checkAud = \case
(VAString aud) -> liftMaybe cfgAud >>= checkValue (aud /=) JWTNotInAudience
(VAArray auds) | (not . null) auds -> liftMaybe cfgAud >>= checkValue (not . (`elem` auds)) JWTNotInAudience
_ -> mempty

liftMaybe = maybe mempty pure
validAud = \case
(VAString aud) -> audMatches aud
(VAArray auds) -> null auds || any audMatches auds

checkValue invalid msg val =
if invalid val then
Expand Down Expand Up @@ -122,8 +118,8 @@ parseToken secret tkn = do
jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError

parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult
parseClaims AppConfig{configJwtAudience, configJwtRoleClaimKey, configDbAnonRole} time mclaims = do
validateClaims time configJwtAudience mclaims
parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do
validateClaims time (audMatchesCfg cfg) mclaims
-- role defaults to anon if not specified in jwt
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
Expand Down
3 changes: 3 additions & 0 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module PostgREST.Config
, addFallbackAppName
, addTargetSessionAttrs
, exampleConfigFile
, audMatchesCfg
) where

import qualified Data.Aeson as JSON
Expand Down Expand Up @@ -67,6 +68,8 @@ import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier, dumpQi,

import Protolude hiding (Proxy, toList)

audMatchesCfg :: AppConfig -> Text -> Bool
audMatchesCfg = maybe (const True) (==) . configJwtAudience

data AppConfig = AppConfig
{ configAppSettings :: [(Text, Text)]
Expand Down