Skip to content

Commit 582b83d

Browse files
committed
refactor: JWT code reorganization to facilitate more granular caching decisions
JWT cache implementation introduced two new modules: Auth.Jwt and Auth.JwtCache. This refactoring reorganizes code in Auth and the above two modules so that reponsibilities and dependencies are more clear: * parseClaims function was moved from Auth.Jwt back to Auth. Thanks to it Auth.Jwt module became independent from AuthResult data structure and role handling. Its only purpose right now is to parse/verify tokens and validate claims * validateClaims function in Auth.Jwt module was split to separate validateAud and validateTimeClaims functions. This change was necessary to allow Auth.JwtCache module to be the only place to decide what validations are cached. * Introduced type level tagging of claim validation results so that it is possible to statically ensure all required validations were performed (see Auth.JwtCache.parseAndValidateClaims signature) * Made Auth.Jwt module independent from Config module: validateAud no longer takes Config as an argument but a (Text -> Bool) function to validate audience values * Auth.JwtCache module was changed so that it is now possilble to cache claims validation results. Tagged claim validation result types are used to ensure all validations are performed regardless of the decision about what should be cached. * JwtCache datatype in Auth.JwtCache module was renamed to CacheState with JwksNotConfigured, NotCaching and Caching constructors. * Creation of a Sieve cache instance was moved to a CacheVariant typeclass function newCache * NeedsReconfiguration typeclass was introduced to handle differences between different CacheVariants in deciding when cache reset is needed (if aud claim validation results are cached we need to reset cache when jwt-aud changes)
1 parent eb908c6 commit 582b83d

File tree

3 files changed

+197
-139
lines changed

3 files changed

+197
-139
lines changed

src/PostgREST/Auth.hs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
{-# LANGUAGE RecordWildCards #-}
1+
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE NamedFieldPuns #-}
3+
{-# LANGUAGE RecordWildCards #-}
24
{-|
35
Module : PostgREST.Auth
46
Description : PostgREST authentication functions.
@@ -30,13 +32,19 @@ import System.TimeIt (timeItT)
3032

3133
import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
3234
getTime)
33-
import PostgREST.Auth.Jwt (parseClaims)
3435
import PostgREST.Auth.JwtCache (lookupJwtCache)
3536
import PostgREST.Auth.Types (AuthResult (..))
36-
import PostgREST.Config (AppConfig (..))
37-
import PostgREST.Error (Error (..))
37+
import PostgREST.Config (AppConfig (..), FilterExp (..),
38+
JSPath, JSPathExp (..))
39+
import PostgREST.Error (Error (..), JwtError (..))
3840

39-
import Protolude
41+
import Control.Monad.Except (liftEither)
42+
import qualified Data.Aeson as JSON
43+
import qualified Data.Aeson.Key as K
44+
import qualified Data.Aeson.KeyMap as KM
45+
import qualified Data.Text as T
46+
import qualified Data.Vector as V
47+
import Protolude
4048

4149
-- | Validate authorization header
4250
-- Parse and store JWT claims for future use in the request.
@@ -46,7 +54,8 @@ middleware appState app req respond = do
4654
time <- getTime appState
4755

4856
let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
49-
parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time
57+
parseToken = maybe (pure KM.empty) (lookupJwtCache jwtCacheState time)
58+
parseJwt = runExceptT $ parseToken >=> parseClaims conf $ token
5059
jwtCacheState = getJwtCacheState appState
5160

5261
-- If ServerTimingEnabled -> calculate JWT validation time
@@ -59,6 +68,38 @@ middleware appState app req respond = do
5968

6069
app req' respond
6170

71+
parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> JSON.Object -> m AuthResult
72+
parseClaims AppConfig{configJwtRoleClaimKey, configDbAnonRole} mclaims = do
73+
-- role defaults to anon if not specified in jwt
74+
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
75+
unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
76+
pure AuthResult
77+
{ authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role)
78+
, authRole = role
79+
}
80+
where
81+
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
82+
walkJSPath x [] = x
83+
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
84+
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
85+
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
86+
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
87+
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
88+
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
89+
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
90+
walkJSPath _ _ = Nothing
91+
92+
findFirstMatch matchWith pattern = foldr checkMatch Nothing
93+
where
94+
checkMatch (JSON.String txt) acc
95+
| pattern `matchWith` txt = Just $ JSON.String txt
96+
| otherwise = acc
97+
checkMatch _ acc = acc
98+
99+
unquoted :: JSON.Value -> BS.ByteString
100+
unquoted (JSON.String t) = encodeUtf8 t
101+
unquoted v = BS.toStrict $ JSON.encode v
102+
62103
authResultKey :: Vault.Key (Either Error AuthResult)
63104
authResultKey = unsafePerformIO Vault.newKey
64105
{-# NOINLINE authResultKey #-}

src/PostgREST/Auth/Jwt.hs

Lines changed: 69 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,108 @@ Description : PostgREST JWT support functions.
44
55
This module provides functions to deal with JWT parsing and validation (http://jwt.io).
66
-}
7+
{-# LANGUAGE DataKinds #-}
78
{-# LANGUAGE DeriveGeneric #-}
89
{-# LANGUAGE FlexibleContexts #-}
9-
{-# LANGUAGE ImpredicativeTypes #-}
1010
{-# LANGUAGE LambdaCase #-}
11-
{-# LANGUAGE NamedFieldPuns #-}
12-
{-# LANGUAGE QuantifiedConstraints #-}
11+
{-# LANGUAGE MultiParamTypeClasses #-}
12+
{-# LANGUAGE PolyKinds #-}
13+
{-# LANGUAGE TypeFamilies #-}
14+
{-# LANGUAGE TypeOperators #-}
1315

1416
module PostgREST.Auth.Jwt
15-
( parseAndDecodeClaims
16-
, parseClaims) where
17-
18-
import qualified Data.Aeson as JSON
19-
import qualified Data.Aeson.Key as K
20-
import qualified Data.Aeson.KeyMap as KM
21-
import qualified Data.ByteString as BS
22-
import qualified Data.ByteString.Internal as BS
23-
import qualified Data.ByteString.Lazy.Char8 as LBS
24-
import qualified Data.Scientific as Sci
25-
import qualified Data.Text as T
26-
import qualified Data.Vector as V
27-
import qualified Jose.Jwk as JWT
28-
import qualified Jose.Jwt as JWT
17+
( Validation (..)
18+
, Validated (getValidated)
19+
, parseAndDecodeClaims
20+
, validateAud
21+
, validateTimeClaims
22+
, (>>>)) where
23+
24+
import qualified Data.Aeson as JSON
25+
import qualified Data.Aeson.KeyMap as KM
26+
import qualified Data.ByteString as BS
27+
import qualified Data.ByteString.Internal as BS
28+
import qualified Data.Scientific as Sci
29+
import qualified Jose.Jwk as JWT
30+
import qualified Jose.Jwt as JWT
2931

3032
import Control.Monad.Except (liftEither)
3133
import Data.Either.Combinators (mapLeft)
3234
import Data.Text ()
3335
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
3436
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
3537

36-
import PostgREST.Auth.Types (AuthResult (..))
37-
import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath,
38-
JSPathExp (..), audMatchesCfg)
39-
import PostgREST.Error (Error (..),
40-
JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed),
41-
JwtDecodeError (..), JwtError (..))
38+
import PostgREST.Error (Error (..),
39+
JwtClaimsError (AudClaimNotStringOrArray, ExpClaimNotNumber, IatClaimNotNumber, JWTExpired, JWTIssuedAtFuture, JWTNotInAudience, JWTNotYetValid, NbfClaimNotNumber, ParsingClaimsFailed),
40+
JwtDecodeError (..), JwtError (..))
4241

4342
import Data.Aeson ((.:?))
4443
import Data.Aeson.Types (parseMaybe)
44+
import Data.Coerce (coerce)
4545
import Jose.Jwk (JwkSet)
4646
import Protolude hiding (first)
4747

48+
-- A value tagged by a type-level list of validations pefrormed on it
49+
newtype Validated (k :: [v]) a = Validated { getValidated :: a }
50+
51+
-- Helper to implement type safe validation chaining
52+
type family (++) (lst::[k]) lst' where
53+
'[] ++ lst = lst
54+
(l : ls) ++ lst = l : (ls ++ lst)
55+
56+
-- Validation chaining operator
57+
(>>>) :: (Monad m, Coercible (m (Validated kc c)) (m (Validated (kb ++ kc) c)))
58+
=> (a -> m (Validated kb b))
59+
-> (b -> m (Validated kc c))
60+
-> a
61+
-> m (Validated (kb ++ kc) c)
62+
f >>> g = coerce . (f >=> g . coerce)
63+
4864
parseAndDecodeClaims :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JSON.Object
49-
parseAndDecodeClaims jwkSet token = parseToken jwkSet token >>= decodeClaims
65+
parseAndDecodeClaims jwkSet = parseToken jwkSet >=> decodeClaims
66+
67+
data Validation = Aud | Time
68+
69+
validateAud :: MonadError Error m => (Text -> Bool) -> JSON.Object -> m (Validated '[Aud] JSON.Object)
70+
validateAud = validate . checkAud
71+
72+
validateTimeClaims :: MonadError Error m => UTCTime -> JSON.Object -> m (Validated '[Time] JSON.Object)
73+
validateTimeClaims = validate . checkExpNbfIat
5074

5175
decodeClaims :: MonadError Error m => JWT.JwtContent -> m JSON.Object
5276
decodeClaims (JWT.Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed)) pure (JSON.decodeStrict claims)
5377
decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType
5478

55-
validateClaims :: MonadError Error m => UTCTime -> (Text -> Bool) -> JSON.Object -> m ()
56-
validateClaims time audMatches claims = liftEither $ maybeToLeft () (fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims)
79+
validate :: MonadError Error m => (t -> Alt Maybe JwtClaimsError) -> t -> m (Validated k t)
80+
validate f claims = fmap Validated $ liftEither $ maybeToLeft claims $ fmap JwtErr . getAlt $ JwtClaimsErr <$> f claims
5781

5882
data ValidAud = VAString Text | VAArray [Text] deriving Generic
5983
instance JSON.FromJSON ValidAud where
6084
parseJSON = JSON.genericParseJSON JSON.defaultOptions { JSON.sumEncoding = JSON.UntaggedValue }
6185

62-
checkForErrors :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> (Text -> Bool) -> JSON.Object -> m JwtClaimsError
63-
checkForErrors time audMatches = mconcat
86+
claim :: (JSON.FromJSON a, Applicative f, Monoid (f p)) => KM.Key -> p -> (a -> f p) -> JSON.Object -> f p
87+
claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)
88+
89+
checkValue :: (Applicative f, Monoid (f p)) => (t -> Bool) -> p -> t -> f p
90+
checkValue invalid msg val =
91+
if invalid val then
92+
pure msg
93+
else
94+
mempty
95+
96+
checkAud :: (Applicative f, Monoid (f JwtClaimsError)) => (Text -> Bool) -> JSON.Object -> f JwtClaimsError
97+
checkAud audMatches = claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
98+
where
99+
validAud = \case
100+
(VAString aud) -> audMatches aud
101+
(VAArray auds) -> null auds || any audMatches auds
102+
103+
checkExpNbfIat :: (Applicative m, Monoid (m JwtClaimsError)) => UTCTime -> JSON.Object -> m JwtClaimsError
104+
checkExpNbfIat time = mconcat
64105
[
65106
claim "exp" ExpClaimNotNumber $ inThePast JWTExpired
66107
, claim "nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid
67108
, claim "iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture
68-
, claim "aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
69109
]
70110
where
71111
allowedSkewSeconds = 30 :: Int64
@@ -78,20 +118,6 @@ checkForErrors time audMatches = mconcat
78118

79119
checkTime cond = checkValue (cond. sciToInt)
80120

81-
validAud = \case
82-
(VAString aud) -> audMatches aud
83-
(VAArray auds) -> null auds || any audMatches auds
84-
85-
checkValue invalid msg val =
86-
if invalid val then
87-
pure msg
88-
else
89-
mempty
90-
91-
claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)
92-
93-
-- | Receives the JWT secret and audience (from config) and a JWT and returns a
94-
-- JSON object of JWT claims.
95121
parseToken :: (MonadError Error m, MonadIO m) => JwkSet -> ByteString -> m JWT.JwtContent
96122
parseToken _ "" = throwError $ JwtErr $ JwtDecodeErr EmptyAuthHeader
97123
parseToken secret tkn = do
@@ -116,36 +142,3 @@ parseToken secret tkn = do
116142
jwtDecodeError JWT.BadCrypto = JwtDecodeErr BadCrypto
117143
-- Control never reaches here, the decode function only returns the above three
118144
jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError
119-
120-
parseClaims :: (MonadError Error m, MonadIO m) => AppConfig -> UTCTime -> JSON.Object -> m AuthResult
121-
parseClaims cfg@AppConfig{configJwtRoleClaimKey, configDbAnonRole} time mclaims = do
122-
validateClaims time (audMatchesCfg cfg) mclaims
123-
-- role defaults to anon if not specified in jwt
124-
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
125-
unquoted <$> walkJSPath (Just $ JSON.Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
126-
pure AuthResult
127-
{ authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role)
128-
, authRole = role
129-
}
130-
where
131-
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
132-
walkJSPath x [] = x
133-
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
134-
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
135-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
136-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
137-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
138-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
139-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
140-
walkJSPath _ _ = Nothing
141-
142-
findFirstMatch matchWith pattern = foldr checkMatch Nothing
143-
where
144-
checkMatch (JSON.String txt) acc
145-
| pattern `matchWith` txt = Just $ JSON.String txt
146-
| otherwise = acc
147-
checkMatch _ acc = acc
148-
149-
unquoted :: JSON.Value -> BS.ByteString
150-
unquoted (JSON.String t) = encodeUtf8 t
151-
unquoted v = LBS.toStrict $ JSON.encode v

0 commit comments

Comments
 (0)