From 71639a2d8997480a5d85442cee80adb856f445bd Mon Sep 17 00:00:00 2001 From: mstott2 Date: Mon, 14 Jul 2025 16:24:11 -0400 Subject: [PATCH 1/2] oidc: add typed error for issuer mismatch Enables OIDC implementations to better distinguish and act on issuer mismatch errors --- oidc/oidc.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/oidc/oidc.go b/oidc/oidc.go index f6a7ea8..2fad308 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -40,6 +40,16 @@ var ( errInvalidAtHash = errors.New("access token hash does not match value in ID token") ) +// IssuerMismatchError is returned when the issuer does not match the expected value. +type IssuerMismatchError struct { + Expected string + Actual string +} + +func (e *IssuerMismatchError) Error() string { + return fmt.Sprintf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", e.Expected, e.Actual) +} + type contextKey int var issuerURLKey contextKey @@ -162,7 +172,7 @@ var supportedAlgorithms = map[string]bool{ // parsing. // // // Directly fetch the metadata document. -// resp, err := http.Get("https://login.example.com/custom-metadata-path") +// resp, err := http.Get("https://login.example.com/custom-metadata-path") // if err != nil { // // ... // } @@ -267,7 +277,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) { issuerURL = issuer } if p.Issuer != issuerURL && !skipIssuerValidation { - return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer) + return nil, &IssuerMismatchError{Expected: issuer, Actual: p.Issuer} } var algs []string for _, a := range p.Algorithms { From 73acf350bf99b5fab2a14391133a101f86d33a07 Mon Sep 17 00:00:00 2001 From: mstott2 Date: Mon, 14 Jul 2025 17:15:10 -0400 Subject: [PATCH 2/2] oidc: update oidc tests --- oidc/oidc_test.go | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/oidc/oidc_test.go b/oidc/oidc_test.go index 27415dd..2b50bfb 100644 --- a/oidc/oidc_test.go +++ b/oidc/oidc_test.go @@ -6,6 +6,7 @@ import ( "crypto/elliptic" "crypto/rand" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -116,17 +117,18 @@ func TestAccessTokenVerification(t *testing.T) { func TestNewProvider(t *testing.T) { tests := []struct { - name string - data string - issuerURLOverride string - trailingSlash bool - wantAuthURL string - wantTokenURL string - wantDeviceAuthURL string - wantUserInfoURL string - wantIssuerURL string - wantAlgorithms []string - wantErr bool + name string + data string + issuerURLOverride string + trailingSlash bool + wantAuthURL string + wantTokenURL string + wantDeviceAuthURL string + wantUserInfoURL string + wantIssuerURL string + wantAlgorithms []string + wantErr bool + wantErrIssuerMismatch bool }{ { name: "basic_case", @@ -306,14 +308,20 @@ func TestNewProvider(t *testing.T) { p, err := NewProvider(ctx, issuer) if err != nil { - if !test.wantErr { + if !test.wantErr && !test.wantErrIssuerMismatch { t.Errorf("NewProvider() failed: %v", err) } return } - if test.wantErr { + if test.wantErr || test.wantErrIssuerMismatch { t.Fatalf("NewProvider(): expected error") } + if test.wantErrIssuerMismatch { + var errExp *IssuerMismatchError + if !errors.As(err, &errExp) { + t.Errorf("expected *IssuerMismatchError but got %q", err) + } + } if test.wantIssuerURL != "" && p.issuer != test.wantIssuerURL { t.Errorf("NewProvider() unexpected issuer value, got=%s, want=%s",