diff --git a/oidc/oidc.go b/oidc/oidc.go index 17419f3..e5aaef7 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -105,26 +105,29 @@ type Provider struct { // Raw claims returned by the server. rawClaims []byte - // Guards all of the following fields. - mu sync.Mutex // HTTP client specified from the initial NewProvider request. This is used // when creating the common key set. client *http.Client + + // Guards all of the following fields. + mu sync.RWMutex // A key set that uses context.Background() and is shared between all code paths // that don't have a convinent way of supplying a unique context. commonRemoteKeySet KeySet } -func (p *Provider) remoteKeySet() KeySet { +func (p *Provider) remoteKeySet(c *http.Client) KeySet { + p.mu.RLock() + if p.commonRemoteKeySet != nil { + defer p.mu.RUnlock() + return p.commonRemoteKeySet + } + p.mu.RUnlock() + p.mu.Lock() defer p.mu.Unlock() - if p.commonRemoteKeySet == nil { - ctx := context.Background() - if p.client != nil { - ctx = ClientContext(ctx, p.client) - } - p.commonRemoteKeySet = NewRemoteKeySet(ctx, p.jwksURL) - } + + p.commonRemoteKeySet = NewRemoteKeySet(ClientContext(context.Background(), c), p.jwksURL) return p.commonRemoteKeySet } @@ -350,7 +353,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) ct := resp.Header.Get("Content-Type") mediaType, _, parseErr := mime.ParseMediaType(ct) if parseErr == nil && mediaType == "application/jwt" { - payload, err := p.remoteKeySet().VerifySignature(ctx, string(body)) + payload, err := p.remoteKeySet(getClient(ctx)).VerifySignature(ctx, string(body)) if err != nil { return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err) } diff --git a/oidc/verify.go b/oidc/verify.go index 52b27b7..9c97b23 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -131,7 +131,7 @@ func (p *Provider) VerifierContext(ctx context.Context, config *Config) *IDToken // The returned verifier uses a background context for all requests to the upstream // JWKs endpoint. To control that context, use VerifierContext instead. func (p *Provider) Verifier(config *Config) *IDTokenVerifier { - return p.newVerifier(p.remoteKeySet(), config) + return p.newVerifier(p.remoteKeySet(p.client), config) } func (p *Provider) newVerifier(keySet KeySet, config *Config) *IDTokenVerifier {