@@ -4,68 +4,108 @@ Description : PostgREST JWT support functions.
44
55This module provides functions to deal with JWT parsing and validation (http://jwt.io).
66-}
7+ {-# LANGUAGE DataKinds #-}
78{-# LANGUAGE DeriveGeneric #-}
89{-# LANGUAGE FlexibleContexts #-}
9- {-# LANGUAGE ImpredicativeTypes #-}
1010{-# LANGUAGE LambdaCase #-}
11- {-# LANGUAGE NamedFieldPuns #-}
12- {-# LANGUAGE QuantifiedConstraints #-}
11+ {-# LANGUAGE MultiParamTypeClasses #-}
12+ {-# LANGUAGE PolyKinds #-}
13+ {-# LANGUAGE TypeFamilies #-}
14+ {-# LANGUAGE TypeOperators #-}
1315
1416module PostgREST.Auth.Jwt
15- ( parseAndDecodeClaims
16- , parseClaims ) where
17-
18- import qualified Data.Aeson as JSON
19- import qualified Data.Aeson.Key as K
20- import qualified Data.Aeson.KeyMap as KM
21- import qualified Data.ByteString as BS
22- import qualified Data.ByteString.Internal as BS
23- import qualified Data.ByteString.Lazy.Char8 as LBS
24- import qualified Data.Scientific as Sci
25- import qualified Data.Text as T
26- import qualified Data.Vector as V
27- import qualified Jose.Jwk as JWT
28- import qualified Jose.Jwt as JWT
17+ ( Validation ( .. )
18+ , Validated ( getValidated )
19+ , parseAndDecodeClaims
20+ , validateAud
21+ , validateTimeClaims
22+ , (>>>) ) where
23+
24+ import qualified Data.Aeson as JSON
25+ import qualified Data.Aeson.KeyMap as KM
26+ import qualified Data.ByteString as BS
27+ import qualified Data.ByteString.Internal as BS
28+ import qualified Data.Scientific as Sci
29+ import qualified Jose.Jwk as JWT
30+ import qualified Jose.Jwt as JWT
2931
3032import Control.Monad.Except (liftEither )
3133import Data.Either.Combinators (mapLeft )
3234import Data.Text ()
3335import Data.Time.Clock (UTCTime , nominalDiffTimeToSeconds )
3436import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds )
3537
36- import PostgREST.Auth.Types (AuthResult (.. ))
37- import PostgREST.Config (AppConfig (.. ), FilterExp (.. ), JSPath ,
38- JSPathExp (.. ), audMatchesCfg )
39- import PostgREST.Error (Error (.. ),
40- JwtClaimsError (AudClaimNotStringOrArray , ExpClaimNotNumber , IatClaimNotNumber , JWTExpired , JWTIssuedAtFuture , JWTNotInAudience , JWTNotYetValid , NbfClaimNotNumber , ParsingClaimsFailed ),
41- JwtDecodeError (.. ), JwtError (.. ))
38+ import PostgREST.Error (Error (.. ),
39+ JwtClaimsError (AudClaimNotStringOrArray , ExpClaimNotNumber , IatClaimNotNumber , JWTExpired , JWTIssuedAtFuture , JWTNotInAudience , JWTNotYetValid , NbfClaimNotNumber , ParsingClaimsFailed ),
40+ JwtDecodeError (.. ), JwtError (.. ))
4241
4342import Data.Aeson ((.:?) )
4443import Data.Aeson.Types (parseMaybe )
44+ import Data.Coerce (coerce )
4545import Jose.Jwk (JwkSet )
4646import Protolude hiding (first )
4747
48+ -- A value tagged by a type-level list of validations pefrormed on it
49+ newtype Validated (k :: [v ]) a = Validated { getValidated :: a }
50+
51+ -- Helper to implement type safe validation chaining
52+ type family (++ ) (lst :: [k ]) lst' where
53+ '[] ++ lst = lst
54+ (l : ls ) ++ lst = l : (ls ++ lst )
55+
56+ -- Validation chaining operator
57+ (>>>) :: (Monad m , Coercible (m (Validated kc c )) (m (Validated (kb ++ kc ) c )))
58+ => (a -> m (Validated kb b ))
59+ -> (b -> m (Validated kc c ))
60+ -> a
61+ -> m (Validated (kb ++ kc ) c )
62+ f >>> g = coerce . (f >=> g . coerce)
63+
4864parseAndDecodeClaims :: (MonadError Error m , MonadIO m ) => JwkSet -> ByteString -> m JSON. Object
49- parseAndDecodeClaims jwkSet token = parseToken jwkSet token >>= decodeClaims
65+ parseAndDecodeClaims jwkSet = parseToken jwkSet >=> decodeClaims
66+
67+ data Validation = Aud | Time
68+
69+ validateAud :: MonadError Error m => (Text -> Bool ) -> JSON. Object -> m (Validated '[Aud ] JSON. Object )
70+ validateAud = validate . checkAud
71+
72+ validateTimeClaims :: MonadError Error m => UTCTime -> JSON. Object -> m (Validated '[Time ] JSON. Object )
73+ validateTimeClaims = validate . checkExpNbfIat
5074
5175decodeClaims :: MonadError Error m => JWT. JwtContent -> m JSON. Object
5276decodeClaims (JWT. Jws (_, claims)) = maybe (throwError (JwtErr $ JwtClaimsErr ParsingClaimsFailed )) pure (JSON. decodeStrict claims)
5377decodeClaims _ = throwError $ JwtErr $ JwtDecodeErr UnsupportedTokenType
5478
55- validateClaims :: MonadError Error m => UTCTime -> ( Text -> Bool ) -> JSON. Object -> m ()
56- validateClaims time audMatches claims = liftEither $ maybeToLeft () ( fmap JwtErr . getAlt $ JwtClaimsErr <$> checkForErrors time audMatches claims)
79+ validate :: MonadError Error m => ( t -> Alt Maybe JwtClaimsError ) -> t -> m (Validated k t )
80+ validate f claims = fmap Validated $ liftEither $ maybeToLeft claims $ fmap JwtErr . getAlt $ JwtClaimsErr <$> f claims
5781
5882data ValidAud = VAString Text | VAArray [Text ] deriving Generic
5983instance JSON. FromJSON ValidAud where
6084 parseJSON = JSON. genericParseJSON JSON. defaultOptions { JSON. sumEncoding = JSON. UntaggedValue }
6185
62- checkForErrors :: (Applicative m , Monoid (m JwtClaimsError )) => UTCTime -> (Text -> Bool ) -> JSON. Object -> m JwtClaimsError
63- checkForErrors time audMatches = mconcat
86+ claim :: (JSON. FromJSON a , Applicative f , Monoid (f p )) => KM. Key -> p -> (a -> f p ) -> JSON. Object -> f p
87+ claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)
88+
89+ checkValue :: (Applicative f , Monoid (f p )) => (t -> Bool ) -> p -> t -> f p
90+ checkValue invalid msg val =
91+ if invalid val then
92+ pure msg
93+ else
94+ mempty
95+
96+ checkAud :: (Applicative f , Monoid (f JwtClaimsError )) => (Text -> Bool ) -> JSON. Object -> f JwtClaimsError
97+ checkAud audMatches = claim " aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
98+ where
99+ validAud = \ case
100+ (VAString aud) -> audMatches aud
101+ (VAArray auds) -> null auds || any audMatches auds
102+
103+ checkExpNbfIat :: (Applicative m , Monoid (m JwtClaimsError )) => UTCTime -> JSON. Object -> m JwtClaimsError
104+ checkExpNbfIat time = mconcat
64105 [
65106 claim " exp" ExpClaimNotNumber $ inThePast JWTExpired
66107 , claim " nbf" NbfClaimNotNumber $ inTheFuture JWTNotYetValid
67108 , claim " iat" IatClaimNotNumber $ inTheFuture JWTIssuedAtFuture
68- , claim " aud" AudClaimNotStringOrArray $ checkValue (not . validAud) JWTNotInAudience
69109 ]
70110 where
71111 allowedSkewSeconds = 30 :: Int64
@@ -78,20 +118,6 @@ checkForErrors time audMatches = mconcat
78118
79119 checkTime cond = checkValue (cond. sciToInt)
80120
81- validAud = \ case
82- (VAString aud) -> audMatches aud
83- (VAArray auds) -> null auds || any audMatches auds
84-
85- checkValue invalid msg val =
86- if invalid val then
87- pure msg
88- else
89- mempty
90-
91- claim key parseError checkParsed = maybe (pure parseError) (maybe mempty checkParsed) . parseMaybe (.:? key)
92-
93- -- | Receives the JWT secret and audience (from config) and a JWT and returns a
94- -- JSON object of JWT claims.
95121parseToken :: (MonadError Error m , MonadIO m ) => JwkSet -> ByteString -> m JWT. JwtContent
96122parseToken _ " " = throwError $ JwtErr $ JwtDecodeErr EmptyAuthHeader
97123parseToken secret tkn = do
@@ -116,36 +142,3 @@ parseToken secret tkn = do
116142 jwtDecodeError JWT. BadCrypto = JwtDecodeErr BadCrypto
117143 -- Control never reaches here, the decode function only returns the above three
118144 jwtDecodeError _ = JwtDecodeErr UnreachableDecodeError
119-
120- parseClaims :: (MonadError Error m , MonadIO m ) => AppConfig -> UTCTime -> JSON. Object -> m AuthResult
121- parseClaims cfg@ AppConfig {configJwtRoleClaimKey, configDbAnonRole} time mclaims = do
122- validateClaims time (audMatchesCfg cfg) mclaims
123- -- role defaults to anon if not specified in jwt
124- role <- liftEither . maybeToRight (JwtErr JwtTokenRequired ) $
125- unquoted <$> walkJSPath (Just $ JSON. Object mclaims) configJwtRoleClaimKey <|> configDbAnonRole
126- pure AuthResult
127- { authClaims = mclaims & KM. insert " role" (JSON. toJSON $ decodeUtf8 role)
128- , authRole = role
129- }
130- where
131- walkJSPath :: Maybe JSON. Value -> JSPath -> Maybe JSON. Value
132- walkJSPath x [] = x
133- walkJSPath (Just (JSON. Object o)) (JSPKey key: rest) = walkJSPath (KM. lookup (K. fromText key) o) rest
134- walkJSPath (Just (JSON. Array ar)) (JSPIdx idx: rest) = walkJSPath (ar V. !? idx) rest
135- walkJSPath (Just (JSON. Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
136- walkJSPath (Just (JSON. Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
137- walkJSPath (Just (JSON. Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T. isPrefixOf txt ar
138- walkJSPath (Just (JSON. Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T. isSuffixOf txt ar
139- walkJSPath (Just (JSON. Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T. isInfixOf txt ar
140- walkJSPath _ _ = Nothing
141-
142- findFirstMatch matchWith pattern = foldr checkMatch Nothing
143- where
144- checkMatch (JSON. String txt) acc
145- | pattern `matchWith` txt = Just $ JSON. String txt
146- | otherwise = acc
147- checkMatch _ acc = acc
148-
149- unquoted :: JSON. Value -> BS. ByteString
150- unquoted (JSON. String t) = encodeUtf8 t
151- unquoted v = LBS. toStrict $ JSON. encode v
0 commit comments