diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 244de0d..9d738b8 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -74,7 +74,7 @@ jobs: make test - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: flags: ${{ matrix.os }},go-${{ matrix.go }} token: ${{ secrets.CODECOV_TOKEN }} diff --git a/discovery/discovery.go b/discovery/discovery.go index d341405..5081f79 100644 --- a/discovery/discovery.go +++ b/discovery/discovery.go @@ -15,6 +15,7 @@ import ( "time" retry "github.com/appleboy/go-httpretry" + "golang.org/x/sync/singleflight" "github.com/go-authgate/sdk-go/oauth" ) @@ -86,6 +87,7 @@ type Client struct { mu sync.RWMutex cached *Metadata fetchedAt time.Time + group singleflight.Group } // Option configures a discovery Client. @@ -144,58 +146,70 @@ func (c *Client) Fetch(ctx context.Context) (*Metadata, error) { } // refresh fetches fresh metadata from the discovery endpoint. +// singleflight coalesces concurrent misses into one HTTP request; the lock is +// held only for the cache check and cache update, not during the network call. func (c *Client) refresh(ctx context.Context) (*Metadata, error) { - c.mu.Lock() - defer c.mu.Unlock() - - // Double-check after acquiring write lock - if c.cached != nil && time.Since(c.fetchedAt) < c.cacheTTL { - return cloneMetadata(c.cached), nil - } + v, err, _ := c.group.Do("fetch", func() (any, error) { + // Double-check after coalescing into the singleflight slot. + c.mu.RLock() + if c.cached != nil && time.Since(c.fetchedAt) < c.cacheTTL { + cp := cloneMetadata(c.cached) + c.mu.RUnlock() + return cp, nil + } + c.mu.RUnlock() - discoveryURL := c.issuerURL + wellKnownPath - resp, err := c.httpClient.Get(ctx, discoveryURL) - if err != nil { - return nil, fmt.Errorf("discovery: fetch %s: %w", discoveryURL, err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf( - "discovery: unexpected status %d from %s", - resp.StatusCode, - discoveryURL, - ) - } + // HTTP fetch happens outside any lock. + discoveryURL := c.issuerURL + wellKnownPath + resp, err := c.httpClient.Get(ctx, discoveryURL) + if err != nil { + return nil, fmt.Errorf("discovery: fetch %s: %w", discoveryURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "discovery: unexpected status %d from %s", + resp.StatusCode, + discoveryURL, + ) + } - var meta Metadata - if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { - return nil, fmt.Errorf("discovery: decode response: %w", err) - } + var meta Metadata + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return nil, fmt.Errorf("discovery: decode response: %w", err) + } - // Validate issuer matches the expected URL (OIDC Discovery 1.0 §4.3) - issuer := strings.TrimRight(meta.Issuer, "/") - if issuer != c.issuerURL { - return nil, fmt.Errorf( - "discovery: issuer mismatch: got %q, expected %q", - meta.Issuer, - c.issuerURL, - ) - } + // Validate issuer matches the expected URL (OIDC Discovery 1.0 §4.3) + issuer := strings.TrimRight(meta.Issuer, "/") + if issuer != c.issuerURL { + return nil, fmt.Errorf( + "discovery: issuer mismatch: got %q, expected %q", + meta.Issuer, + c.issuerURL, + ) + } - // AuthGate uses a fixed device authorization path. Derive it from issuer - // when not explicitly advertised in the discovery response. - if meta.DeviceAuthorizationEndpoint == "" { - meta.DeviceAuthorizationEndpoint = issuer + "/oauth/device/code" - } + // AuthGate uses a fixed device authorization path. Derive it from issuer + // when not explicitly advertised in the discovery response. + if meta.DeviceAuthorizationEndpoint == "" { + meta.DeviceAuthorizationEndpoint = issuer + "/oauth/device/code" + } - // AuthGate has /oauth/introspect but doesn't yet advertise it in discovery - if meta.IntrospectionEndpoint == "" { - meta.IntrospectionEndpoint = issuer + "/oauth/introspect" - } + // AuthGate has /oauth/introspect but doesn't yet advertise it in discovery + if meta.IntrospectionEndpoint == "" { + meta.IntrospectionEndpoint = issuer + "/oauth/introspect" + } - c.cached = &meta - c.fetchedAt = time.Now() + c.mu.Lock() + c.cached = &meta + c.fetchedAt = time.Now() + c.mu.Unlock() - return cloneMetadata(&meta), nil + return cloneMetadata(&meta), nil + }) + if err != nil { + return nil, err + } + return v.(*Metadata), nil } diff --git a/middleware/middleware.go b/middleware/middleware.go index ec4c3c9..1371407 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -108,12 +108,20 @@ type errorResponse struct { func defaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) { var oauthErr *oauth.Error if errors.As(err, &oauthErr) { - if oauthErr.Code == "server_error" { + switch oauthErr.Code { + case "server_error": writeJSON(w, http.StatusInternalServerError, errorResponse{ Error: oauthErr.Code, Description: oauthErr.Description, }) return + case "insufficient_scope": + w.Header().Set("WWW-Authenticate", `Bearer error="insufficient_scope"`) + writeJSON(w, http.StatusForbidden, errorResponse{ + Error: oauthErr.Code, + Description: oauthErr.Description, + }) + return } // All other OAuth errors → 401 with WWW-Authenticate @@ -178,7 +186,10 @@ func BearerAuth(opts ...Option) func(http.Handler) http.Handler { // Check required scopes for _, scope := range cfg.requiredScopes { if !info.HasScope(scope) { - writeInsufficientScope(w, scope) + cfg.errorHandler(w, r, &oauth.Error{ + Code: "insufficient_scope", + Description: "Token does not have required scope: " + scope, + }) return } }