Skip to content

Commit b57f007

Browse files
committed
feat: add validator for generic custom claims
1 parent aaa45e4 commit b57f007

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

validate.go

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ type JSONWebKey struct {
4747
X5c []string `json:"x5c,omitempty"` // Can be used as fallback or primary source
4848
}
4949

50-
type JWTValidator struct {
50+
type JWTValidator[T jwt.Claims] struct {
51+
newClaims func() T
5152
JWKSFetcher *JWKSFetcher
5253
audiences []string
5354
validMethods []string
@@ -56,15 +57,31 @@ type JWTValidator struct {
5657
keyFunc jwt.Keyfunc
5758
}
5859

59-
// NewJWTValidator creates a new JWTValidator struct.
60+
// NewJWTValidator creates a new JWTValidator.
6061
//
6162
// Empty audience, issuer or validMethods results in all tokens being rejected.
62-
func NewJWTValidator(fetcher *JWKSFetcher, validIssuer string, audiences, validMethods []string) (*JWTValidator, error) {
63+
func NewJWTValidator(fetcher *JWKSFetcher, validIssuer string, audiences, validMethods []string) (*JWTValidator[*UserClaims], error) {
64+
return NewJWTValidatorWithClaims(fetcher, validIssuer, audiences, validMethods, func() *UserClaims { return &UserClaims{} })
65+
}
66+
67+
// NewJWTValidatorWithClaims creates a new JWTValidator struct with custom claims.
68+
//
69+
// Type will be infered from user provided claims constuctor.
70+
// T most be a pointer type for JSON deconding to work.
71+
// Empty audience or validMethods results in all tokens being rejected.
72+
func NewJWTValidatorWithClaims[T jwt.Claims](
73+
fetcher *JWKSFetcher,
74+
validIssuer string,
75+
audiences []string,
76+
validMethods []string,
77+
newClaims func() T,
78+
) (*JWTValidator[T], error) {
6379
if len(validIssuer) == 0 {
6480
return nil, fmt.Errorf("issuer not configured")
6581
}
6682

67-
v := &JWTValidator{
83+
v := &JWTValidator[T]{
84+
newClaims: newClaims,
6885
JWKSFetcher: fetcher,
6986
audiences: audiences,
7087
validMethods: validMethods,
@@ -79,7 +96,7 @@ func NewJWTValidator(fetcher *JWKSFetcher, validIssuer string, audiences, validM
7996
// JWTMiddleware takes a JWTValidator and return a function.
8097
// The returned function takes in and returns a http.Handler.
8198
// The returned http.HandlerFunc is the actual middleware.
82-
func JWTMiddleware(validator *JWTValidator) func(http.Handler) http.Handler {
99+
func JWTMiddleware[T jwt.Claims](validator *JWTValidator[T]) func(http.Handler) http.Handler {
83100
return func(next http.Handler) http.Handler {
84101
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
85102
authHeader := r.Header.Get("Authorization")
@@ -156,7 +173,7 @@ func parseKey(jwk *JSONWebKey) (interface{}, error) {
156173
// A key lookup function accepts a parsed JWT token and returns the corresponding public key
157174
// that was used to sign it, if any is found.
158175
// Also validates that the key is not an encryption key.
159-
func (v *JWTValidator) createKeyFunc() jwt.Keyfunc {
176+
func (v *JWTValidator[T]) createKeyFunc() jwt.Keyfunc {
160177
return func(token *jwt.Token) (interface{}, error) {
161178
kid, ok := token.Header["kid"].(string)
162179
if !ok {
@@ -200,15 +217,20 @@ func (v *JWTValidator) createKeyFunc() jwt.Keyfunc {
200217
// ValidateJWT uses a JWTValidator to validate any standalone JWT.
201218
// Accepts a JWT string and returns any claims specified in the UserClaims struct.
202219
// Returns claims even if there is an error parsing.
203-
func (v *JWTValidator) ValidateJWT(ctx context.Context, tokenStr string) (*UserClaims, error) {
204-
claims := &UserClaims{}
220+
func (v *JWTValidator[T]) ValidateJWT(ctx context.Context, tokenStr string) (T, error) {
221+
222+
claims := v.newClaims()
205223
// Parse and validate token.
206224
token, err := jwt.ParseWithClaims(tokenStr, claims, v.keyFunc,
207225
jwt.WithValidMethods(v.validMethods),
208226
jwt.WithIssuer(v.validIssuer))
209227
if err != nil {
228+
var issuer string
229+
if iss, err := claims.GetIssuer(); err == nil {
230+
issuer = iss
231+
}
210232
msg := "failed to parse jwt token with claims"
211-
v.logger.ErrorContext(ctx, msg, "error", err, "iss", claims.Issuer, "valid iss", v.validIssuer)
233+
v.logger.ErrorContext(ctx, msg, "error", err, "iss", issuer, "valid iss", v.validIssuer)
212234
return claims, fmt.Errorf("%s: %w", msg, err)
213235
}
214236

@@ -218,9 +240,15 @@ func (v *JWTValidator) ValidateJWT(ctx context.Context, tokenStr string) (*UserC
218240
return claims, ErrInvalidToken
219241
}
220242

243+
aud, err := claims.GetAudience()
244+
if err != nil {
245+
v.logger.ErrorContext(ctx, "failed to get audience from token")
246+
return claims, ErrInvalidAud
247+
}
248+
221249
// Check for valid audience
222-
if !isAudienceValid(claims.Audience, v.audiences) {
223-
v.logger.ErrorContext(ctx, "token audience validation failed", "audiences", claims.Audience)
250+
if !isAudienceValid(aud, v.audiences) {
251+
v.logger.ErrorContext(ctx, "token audience validation failed", "audiences", aud)
224252
return claims, ErrInvalidAud
225253
}
226254

0 commit comments

Comments
 (0)