Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ var (
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
)

// IssuerMismatchError is returned when the issuer does not match the expected value.
type IssuerMismatchError struct {
Expected string
Actual string
}

func (e *IssuerMismatchError) Error() string {
return fmt.Sprintf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", e.Expected, e.Actual)
}

type contextKey int

var issuerURLKey contextKey
Expand Down Expand Up @@ -162,7 +172,7 @@ var supportedAlgorithms = map[string]bool{
// parsing.
//
// // Directly fetch the metadata document.
// resp, err := http.Get("https://login.example.com/custom-metadata-path")
// resp, err := http.Get("https://login.example.com/custom-metadata-path")
// if err != nil {
// // ...
// }
Expand Down Expand Up @@ -267,7 +277,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
issuerURL = issuer
}
if p.Issuer != issuerURL && !skipIssuerValidation {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
return nil, &IssuerMismatchError{Expected: issuer, Actual: p.Issuer}
}
var algs []string
for _, a := range p.Algorithms {
Expand Down
34 changes: 21 additions & 13 deletions oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -116,17 +117,18 @@ func TestAccessTokenVerification(t *testing.T) {

func TestNewProvider(t *testing.T) {
tests := []struct {
name string
data string
issuerURLOverride string
trailingSlash bool
wantAuthURL string
wantTokenURL string
wantDeviceAuthURL string
wantUserInfoURL string
wantIssuerURL string
wantAlgorithms []string
wantErr bool
name string
data string
issuerURLOverride string
trailingSlash bool
wantAuthURL string
wantTokenURL string
wantDeviceAuthURL string
wantUserInfoURL string
wantIssuerURL string
wantAlgorithms []string
wantErr bool
wantErrIssuerMismatch bool
}{
{
name: "basic_case",
Expand Down Expand Up @@ -306,14 +308,20 @@ func TestNewProvider(t *testing.T) {

p, err := NewProvider(ctx, issuer)
if err != nil {
if !test.wantErr {
if !test.wantErr && !test.wantErrIssuerMismatch {
t.Errorf("NewProvider() failed: %v", err)
}
return
}
if test.wantErr {
if test.wantErr || test.wantErrIssuerMismatch {
t.Fatalf("NewProvider(): expected error")
}
if test.wantErrIssuerMismatch {
var errExp *IssuerMismatchError
if !errors.As(err, &errExp) {
t.Errorf("expected *IssuerMismatchError but got %q", err)
}
}

if test.wantIssuerURL != "" && p.issuer != test.wantIssuerURL {
t.Errorf("NewProvider() unexpected issuer value, got=%s, want=%s",
Expand Down