@@ -47,7 +47,8 @@ type JSONWebKey struct {
4747 X5c []string `json:"x5c,omitempty"` // Can be used as fallback or primary source
4848}
4949
50- type JWTValidator struct {
50+ type JWTValidator [T jwt.Claims ] struct {
51+ newClaims func () T
5152 JWKSFetcher * JWKSFetcher
5253 audiences []string
5354 validMethods []string
@@ -56,15 +57,31 @@ type JWTValidator struct {
5657 keyFunc jwt.Keyfunc
5758}
5859
59- // NewJWTValidator creates a new JWTValidator struct .
60+ // NewJWTValidator creates a new JWTValidator.
6061//
6162// Empty audience, issuer or validMethods results in all tokens being rejected.
62- func NewJWTValidator (fetcher * JWKSFetcher , validIssuer string , audiences , validMethods []string ) (* JWTValidator , error ) {
63+ func NewJWTValidator (fetcher * JWKSFetcher , validIssuer string , audiences , validMethods []string ) (* JWTValidator [* UserClaims ], error ) {
64+ return NewJWTValidatorWithClaims (fetcher , validIssuer , audiences , validMethods , func () * UserClaims { return & UserClaims {} })
65+ }
66+
67+ // NewJWTValidatorWithClaims creates a new JWTValidator struct with custom claims.
68+ //
69+ // Type will be infered from user provided claims constuctor.
70+ // T most be a pointer type for JSON deconding to work.
71+ // Empty audience or validMethods results in all tokens being rejected.
72+ func NewJWTValidatorWithClaims [T jwt.Claims ](
73+ fetcher * JWKSFetcher ,
74+ validIssuer string ,
75+ audiences []string ,
76+ validMethods []string ,
77+ newClaims func () T ,
78+ ) (* JWTValidator [T ], error ) {
6379 if len (validIssuer ) == 0 {
6480 return nil , fmt .Errorf ("issuer not configured" )
6581 }
6682
67- v := & JWTValidator {
83+ v := & JWTValidator [T ]{
84+ newClaims : newClaims ,
6885 JWKSFetcher : fetcher ,
6986 audiences : audiences ,
7087 validMethods : validMethods ,
@@ -79,7 +96,7 @@ func NewJWTValidator(fetcher *JWKSFetcher, validIssuer string, audiences, validM
7996// JWTMiddleware takes a JWTValidator and return a function.
8097// The returned function takes in and returns a http.Handler.
8198// The returned http.HandlerFunc is the actual middleware.
82- func JWTMiddleware (validator * JWTValidator ) func (http.Handler ) http.Handler {
99+ func JWTMiddleware [ T jwt. Claims ] (validator * JWTValidator [ T ] ) func (http.Handler ) http.Handler {
83100 return func (next http.Handler ) http.Handler {
84101 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
85102 authHeader := r .Header .Get ("Authorization" )
@@ -156,7 +173,7 @@ func parseKey(jwk *JSONWebKey) (interface{}, error) {
156173// A key lookup function accepts a parsed JWT token and returns the corresponding public key
157174// that was used to sign it, if any is found.
158175// Also validates that the key is not an encryption key.
159- func (v * JWTValidator ) createKeyFunc () jwt.Keyfunc {
176+ func (v * JWTValidator [ T ] ) createKeyFunc () jwt.Keyfunc {
160177 return func (token * jwt.Token ) (interface {}, error ) {
161178 kid , ok := token .Header ["kid" ].(string )
162179 if ! ok {
@@ -200,15 +217,20 @@ func (v *JWTValidator) createKeyFunc() jwt.Keyfunc {
200217// ValidateJWT uses a JWTValidator to validate any standalone JWT.
201218// Accepts a JWT string and returns any claims specified in the UserClaims struct.
202219// Returns claims even if there is an error parsing.
203- func (v * JWTValidator ) ValidateJWT (ctx context.Context , tokenStr string ) (* UserClaims , error ) {
204- claims := & UserClaims {}
220+ func (v * JWTValidator [T ]) ValidateJWT (ctx context.Context , tokenStr string ) (T , error ) {
221+
222+ claims := v .newClaims ()
205223 // Parse and validate token.
206224 token , err := jwt .ParseWithClaims (tokenStr , claims , v .keyFunc ,
207225 jwt .WithValidMethods (v .validMethods ),
208226 jwt .WithIssuer (v .validIssuer ))
209227 if err != nil {
228+ var issuer string
229+ if iss , err := claims .GetIssuer (); err == nil {
230+ issuer = iss
231+ }
210232 msg := "failed to parse jwt token with claims"
211- v .logger .ErrorContext (ctx , msg , "error" , err , "iss" , claims . Issuer , "valid iss" , v .validIssuer )
233+ v .logger .ErrorContext (ctx , msg , "error" , err , "iss" , issuer , "valid iss" , v .validIssuer )
212234 return claims , fmt .Errorf ("%s: %w" , msg , err )
213235 }
214236
@@ -218,9 +240,15 @@ func (v *JWTValidator) ValidateJWT(ctx context.Context, tokenStr string) (*UserC
218240 return claims , ErrInvalidToken
219241 }
220242
243+ aud , err := claims .GetAudience ()
244+ if err != nil {
245+ v .logger .ErrorContext (ctx , "failed to get audience from token" )
246+ return claims , ErrInvalidAud
247+ }
248+
221249 // Check for valid audience
222- if ! isAudienceValid (claims . Audience , v .audiences ) {
223- v .logger .ErrorContext (ctx , "token audience validation failed" , "audiences" , claims . Audience )
250+ if ! isAudienceValid (aud , v .audiences ) {
251+ v .logger .ErrorContext (ctx , "token audience validation failed" , "audiences" , aud )
224252 return claims , ErrInvalidAud
225253 }
226254
0 commit comments