diff --git a/oidc/errors.go b/oidc/errors.go new file mode 100644 index 00000000..917c3b0f --- /dev/null +++ b/oidc/errors.go @@ -0,0 +1,43 @@ +package oidc + +import ( + "fmt" + "time" +) + +// TokenExpiredError indicates that Verify failed because the token was expired. This +// error does NOT indicate that the token is not also invalid for other reasons. Other +// checks might have failed if the expiration check had not failed. +type TokenExpiredError struct { + // Expiry is the time when the token expired. + Expiry time.Time +} + +func (e *TokenExpiredError) Error() string { + return fmt.Sprintf("oidc: token is expired (Token Expiry: %v)", e.Expiry) +} + +// InvalidIssuerError indicates that Verify failed because the token was issued +// by an unexpected issuer. This error does NOT indicate that the token is not +// also invalid for other reasons. Other checks might have failed if the issuer +// check had not failed. +type InvalidIssuerError struct { + Expected, Actual string +} + +func (e *InvalidIssuerError) Error() string { + return fmt.Sprintf("oidc: id token issued by a different provider, expected %q got %q", e.Expected, e.Actual) +} + +// InvalidAudienceError indicates that Verify failed because the token was +// intended for a different audience. This error does NOT indicate that the +// token is not also invalid for other reasons. Other checks might have failed +// if the audience check had not failed. +type InvalidAudienceError struct { + Expected string + Actual []string +} + +func (e *InvalidAudienceError) Error() string { + return fmt.Sprintf("oidc: expected audience %q got %q", e.Expected, e.Actual) +} diff --git a/oidc/verify.go b/oidc/verify.go index 0bca49a8..776657b1 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -21,19 +21,7 @@ const ( issuerGoogleAccountsNoScheme = "accounts.google.com" ) -// TokenExpiredError indicates that Verify failed because the token was expired. This -// error does NOT indicate that the token is not also invalid for other reasons. Other -// checks might have failed if the expiration check had not failed. -type TokenExpiredError struct { - // Expiry is the time when the token expired. - Expiry time.Time -} - -func (e *TokenExpiredError) Error() string { - return fmt.Sprintf("oidc: token is expired (Token Expiry: %v)", e.Expiry) -} - -// KeySet is a set of publc JSON Web Keys that can be used to validate the signature +// KeySet is a set of public JSON Web Keys that can be used to validate the signature // of JSON web tokens. This is expected to be backed by a remote key set through // provider metadata discovery or an in-memory set of keys delivered out-of-band. type KeySet interface { @@ -264,7 +252,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok // // We will not add hooks to let other providers go off spec like this. if !(v.issuer == issuerGoogleAccounts && t.Issuer == issuerGoogleAccountsNoScheme) { - return nil, fmt.Errorf("oidc: id token issued by a different provider, expected %q got %q", v.issuer, t.Issuer) + return nil, &InvalidIssuerError{Expected: v.issuer, Actual: t.Issuer} } } @@ -274,7 +262,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok if !v.config.SkipClientIDCheck { if v.config.ClientID != "" { if !contains(t.Audience, v.config.ClientID) { - return nil, fmt.Errorf("oidc: expected audience %q got %q", v.config.ClientID, t.Audience) + return nil, &InvalidAudienceError{Expected: v.config.ClientID, Actual: t.Audience} } } else { return nil, fmt.Errorf("oidc: invalid configuration, clientID must be provided or SkipClientIDCheck must be set") diff --git a/oidc/verify_test.go b/oidc/verify_test.go index f2e2433b..743d02d5 100644 --- a/oidc/verify_test.go +++ b/oidc/verify_test.go @@ -5,11 +5,13 @@ import ( "crypto" "encoding/base64" "errors" + "fmt" "io" "net/http" "net/http/httptest" "reflect" "strconv" + "strings" "testing" "time" ) @@ -24,6 +26,7 @@ func TestVerify(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "good eddsa token", @@ -34,6 +37,7 @@ func TestVerify(t *testing.T) { SupportedSigningAlgs: []string{EdDSA}, }, signKey: newEdDSAKey(t), + errFunc: expectSuccess, }, { name: "invalid issuer", @@ -44,7 +48,10 @@ func TestVerify(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), - wantErr: true, + errFunc: expectAll( + expectErrorType[*InvalidIssuerError], + expectErrorMessage(`oidc: id token issued by a different provider, expected "https://bar" got "https://foo"`), + ), }, { name: "skip issuer check", @@ -56,6 +63,7 @@ func TestVerify(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "invalid sig", @@ -66,7 +74,7 @@ func TestVerify(t *testing.T) { }, signKey: newRSAKey(t), verificationKey: newRSAKey(t), - wantErr: true, + errFunc: expectError, }, { name: "google accounts without scheme", @@ -77,6 +85,7 @@ func TestVerify(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "expired token", @@ -84,8 +93,8 @@ func TestVerify(t *testing.T) { config: Config{ SkipClientIDCheck: true, }, - signKey: newRSAKey(t), - wantErrExpiry: true, + signKey: newRSAKey(t), + errFunc: expectErrorType[*TokenExpiredError], }, { name: "unexpired token", @@ -94,6 +103,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "expiry as float", @@ -104,6 +114,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "nbf in future", @@ -113,7 +124,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, }, signKey: newRSAKey(t), - wantErr: true, + errFunc: expectError, }, { name: "nbf in past", @@ -123,6 +134,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "nbf in future within clock skew tolerance", @@ -132,6 +144,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "unsigned token", @@ -140,7 +153,7 @@ func TestVerify(t *testing.T) { SkipClientIDCheck: true, SkipExpiryCheck: true, }, - wantErr: true, + errFunc: expectError, }, { name: "unsigned token InsecureSkipSignatureCheck", @@ -150,6 +163,7 @@ func TestVerify(t *testing.T) { SkipExpiryCheck: true, InsecureSkipSignatureCheck: true, }, + errFunc: expectSuccess, }, } for _, test := range tests { @@ -167,6 +181,7 @@ func TestVerifyAudience(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "mismatched audience", @@ -176,7 +191,10 @@ func TestVerifyAudience(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), - wantErr: true, + errFunc: expectAll( + expectErrorType[*InvalidAudienceError], + expectErrorMessage(`oidc: expected audience "client1" got ["client2"]`), + ), }, { name: "multiple audiences, one matches", @@ -186,6 +204,7 @@ func TestVerifyAudience(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, } for _, test := range tests { @@ -203,6 +222,7 @@ func TestVerifySigningAlg(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, }, { name: "bad signing alg", @@ -212,7 +232,7 @@ func TestVerifySigningAlg(t *testing.T) { SkipExpiryCheck: true, }, signKey: newECDSAKey(t), - wantErr: true, + errFunc: expectError, }, { name: "ecdsa signing", @@ -223,6 +243,7 @@ func TestVerifySigningAlg(t *testing.T) { SkipExpiryCheck: true, }, signKey: newECDSAKey(t), + errFunc: expectSuccess, }, { name: "eddsa signing", @@ -233,6 +254,7 @@ func TestVerifySigningAlg(t *testing.T) { SupportedSigningAlgs: []string{EdDSA}, }, signKey: newEdDSAKey(t), + errFunc: expectSuccess, }, { name: "one of many supported", @@ -243,6 +265,7 @@ func TestVerifySigningAlg(t *testing.T) { SupportedSigningAlgs: []string{RS256, ES256}, }, signKey: newECDSAKey(t), + errFunc: expectSuccess, }, { name: "not in requiredAlgs", @@ -253,7 +276,7 @@ func TestVerifySigningAlg(t *testing.T) { SkipExpiryCheck: true, }, signKey: newECDSAKey(t), - wantErr: true, + errFunc: expectError, }, } for _, test := range tests { @@ -271,6 +294,7 @@ func TestAccessTokenHash(t *testing.T) { SkipExpiryCheck: true, }, signKey: newRSAKey(t), + errFunc: expectSuccess, } t.Run("at_hash", func(t *testing.T) { tok, err := vt.runGetToken(t) @@ -324,7 +348,7 @@ func TestDistributedClaims(t *testing.T) { signKey: newRSAKey(t), }, want: map[string]claimSource{ - "address": claimSource{Endpoint: "123", AccessToken: "1234"}, + "address": {Endpoint: "123", AccessToken: "1234"}, }, }, { @@ -347,8 +371,8 @@ func TestDistributedClaims(t *testing.T) { signKey: newRSAKey(t), }, want: map[string]claimSource{ - "address": claimSource{Endpoint: "123", AccessToken: "1234"}, - "phone_number": claimSource{Endpoint: "123", AccessToken: "1234"}, + "address": {Endpoint: "123", AccessToken: "1234"}, + "phone_number": {Endpoint: "123", AccessToken: "1234"}, }, }, { @@ -554,6 +578,8 @@ func (v resolverTest) testEndpoint(t *testing.T) ([]byte, error) { return resolveDistributedClaim(ctx, verifier, src) } +type errCheck func(error) string + type verificationTest struct { // Name of the subtest. name string @@ -570,9 +596,8 @@ type verificationTest struct { // testing invalid signatures. verificationKey *signingKey - config Config - wantErr bool - wantErrExpiry bool + config Config + errFunc func(error) string } func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) { @@ -605,16 +630,60 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) { func (v verificationTest) run(t *testing.T) { _, err := v.runGetToken(t) - if err != nil && !v.wantErr && !v.wantErrExpiry { - t.Errorf("%v", err) + if msg := v.errFunc(err); msg != "" { + t.Error(msg) + } +} + +func expectError(err error) string { + if err == nil { + return "expected error, got nil" + } + + return "" +} + +func expectSuccess(err error) string { + if err != nil { + return fmt.Sprintf("expected no error, got %v", err) } - if err == nil && (v.wantErr || v.wantErrExpiry) { - t.Errorf("expected error") + + return "" +} + +func expectErrorType[T error](err error) string { + var errT T + if !errors.As(err, &errT) { + return fmt.Sprintf("expected error type %T but got %T", errT, err) + } + + return "" +} + +func expectAll(checks ...errCheck) errCheck { + return func(err error) string { + var messages []string + + for _, check := range checks { + if msg := check(err); msg != "" { + messages = append(messages, msg) + } + } + + return strings.Join(messages, "\n") } - if v.wantErrExpiry { - var errExp *TokenExpiredError - if !errors.As(err, &errExp) { - t.Errorf("expected *TokenExpiryError but got %q", err) +} + +func expectErrorMessage(expectedMsg string) errCheck { + return func(err error) string { + if err == nil { + return fmt.Sprintf("expected error %q, got nil", expectedMsg) + } + + if err.Error() != expectedMsg { + return fmt.Sprintf("expected error %q, got %q", expectedMsg, err.Error()) } + + return "" } }