diff --git a/internal/wrappers/azure-http.go b/internal/wrappers/azure-http.go index 91409ba17..8ce088218 100644 --- a/internal/wrappers/azure-http.go +++ b/internal/wrappers/azure-http.go @@ -108,7 +108,12 @@ func (g *AzureHTTPWrapper) get( queryParams map[string]string, authFormat string, ) (bool, error) { - resp, err := GetWithQueryParams(g.client, url, token, authFormat, queryParams) + resp, err := WithSCMRateLimitRetry( + AzureRateLimitConfig, + func() (*http.Response, error) { + return GetWithQueryParams(g.client, url, token, authFormat, queryParams) + }, + ) if err != nil { return false, err } diff --git a/internal/wrappers/bitbucket-http.go b/internal/wrappers/bitbucket-http.go index 8be2abf34..2a6e3af85 100644 --- a/internal/wrappers/bitbucket-http.go +++ b/internal/wrappers/bitbucket-http.go @@ -150,7 +150,12 @@ func (g *BitBucketHTTPWrapper) getFromBitBucket( logger.PrintIfVerbose(fmt.Sprintf("Request to %s", url)) - resp, err := GetWithQueryParams(g.client, url, token, basicFormat, queryParams) + resp, err := WithSCMRateLimitRetry( + BitbucketRateLimitConfig, + func() (*http.Response, error) { + return GetWithQueryParams(g.client, url, token, basicFormat, queryParams) + }, + ) if err != nil { return err } @@ -264,7 +269,12 @@ func collectPageBitBucket( } func getBitBucket(client *http.Client, token, url string, target interface{}, queryParams map[string]string) error { - resp, err := GetWithQueryParams(client, url, token, basicFormat, queryParams) + resp, err := WithSCMRateLimitRetry( + BitbucketRateLimitConfig, + func() (*http.Response, error) { + return GetWithQueryParams(client, url, token, basicFormat, queryParams) + }, + ) if err != nil { return err } diff --git a/internal/wrappers/bitbucketserver/bitbucket-server-http.go b/internal/wrappers/bitbucketserver/bitbucket-server-http.go index b37b5f8d1..15552f279 100644 --- a/internal/wrappers/bitbucketserver/bitbucket-server-http.go +++ b/internal/wrappers/bitbucketserver/bitbucket-server-http.go @@ -162,7 +162,12 @@ func getBitBucketServer( } req.URL.RawQuery = q.Encode() - resp, err := client.Do(req) + resp, err := wrappers.WithSCMRateLimitRetry( + wrappers.BitbucketRateLimitConfig, + func() (*http.Response, error) { + return client.Do(req) + }, + ) if err != nil { return err } diff --git a/internal/wrappers/github-http.go b/internal/wrappers/github-http.go index 6ba2d31e0..c8638041a 100644 --- a/internal/wrappers/github-http.go +++ b/internal/wrappers/github-http.go @@ -244,7 +244,12 @@ func get(client *http.Client, url string, target interface{}, queryParams map[st req.Header.Add(acceptHeader, apiVersion) token := viper.GetString(params.SCMTokenFlag) logger.PrintRequest(req) - resp, err := GetWithQueryParamsAndCustomRequest(client, req, url, token, tokenFormat, queryParams) + resp, err := WithSCMRateLimitRetry( + GitHubRateLimitConfig, + func() (*http.Response, error) { + return GetWithQueryParamsAndCustomRequest(client, req, url, token, tokenFormat, queryParams) + }, + ) if err != nil { return nil, err } diff --git a/internal/wrappers/gitlab-http.go b/internal/wrappers/gitlab-http.go index 6b4b53d66..91ac46823 100644 --- a/internal/wrappers/gitlab-http.go +++ b/internal/wrappers/gitlab-http.go @@ -137,7 +137,12 @@ func getFromGitLab( logger.PrintRequest(req) - resp, err := GetWithQueryParamsAndCustomRequest(client, req, requestURL, token, bearerFormat, queryParams) + resp, err := WithSCMRateLimitRetry( + GitLabRateLimitConfig, + func() (*http.Response, error) { + return GetWithQueryParamsAndCustomRequest(client, req, requestURL, token, bearerFormat, queryParams) + }, + ) if err != nil { return nil, err } diff --git a/internal/wrappers/rate-limit.go b/internal/wrappers/rate-limit.go new file mode 100644 index 000000000..b204e9d00 --- /dev/null +++ b/internal/wrappers/rate-limit.go @@ -0,0 +1,165 @@ +package wrappers + +import ( + "log" + "net/http" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +const defaultRateLimitWaitSeconds = 60 + +// SCMRateLimitConfig holds rate limit configuration for different SCM providers +type SCMRateLimitConfig struct { + Provider string + ResetHeaderName string + RemainingHeaderName string + LimitHeaderName string + RateLimitStatusCodes []int + DefaultWaitTime time.Duration +} + +// Common SCM rate limit configurations +var ( + GitHubRateLimitConfig = &SCMRateLimitConfig{ + Provider: "GitHub", + ResetHeaderName: "X-RateLimit-Reset", + RemainingHeaderName: "X-RateLimit-Remaining", + LimitHeaderName: "X-RateLimit-Limit", + RateLimitStatusCodes: []int{403, 429}, + DefaultWaitTime: defaultRateLimitWaitSeconds * time.Second, + } + + GitLabRateLimitConfig = &SCMRateLimitConfig{ + Provider: "GitLab", + ResetHeaderName: "RateLimit-Reset", + RemainingHeaderName: "RateLimit-Remaining", + LimitHeaderName: "RateLimit-Limit", + RateLimitStatusCodes: []int{429}, + DefaultWaitTime: defaultRateLimitWaitSeconds * time.Second, + } + + BitbucketRateLimitConfig = &SCMRateLimitConfig{ + Provider: "Bitbucket", + ResetHeaderName: "X-RateLimit-Reset", + RemainingHeaderName: "X-RateLimit-Remaining", + LimitHeaderName: "X-RateLimit-Limit", + RateLimitStatusCodes: []int{429}, + DefaultWaitTime: defaultRateLimitWaitSeconds * time.Second, + } + + AzureRateLimitConfig = &SCMRateLimitConfig{ + Provider: "Azure", + ResetHeaderName: "X-Ratelimit-Reset", + RemainingHeaderName: "X-Ratelimit-Remaining", + LimitHeaderName: "X-Ratelimit-Limit", + RateLimitStatusCodes: []int{429}, + DefaultWaitTime: defaultRateLimitWaitSeconds * time.Second, + } +) + +// SCMRateLimitError represents a rate limit error from any SCM provider +type SCMRateLimitError struct { + Provider string + ResetTime int64 + Message string +} + +func (e *SCMRateLimitError) Error() string { + if e.Message != "" { + return e.Message + } + return e.Provider + " API rate limit exceeded" +} + +func (e *SCMRateLimitError) RetryAfter() time.Duration { + if e.ResetTime > 0 { + reset := time.Unix(e.ResetTime, 0) + now := time.Now() + if reset.After(now) { + return reset.Sub(now) + (defaultRateLimitWaitSeconds * time.Second) // add buffer for 60 seconds + } + } + return defaultRateLimitWaitSeconds * time.Second +} + +// WithSCMRateLimitRetry wraps any SCM API call with rate limit retry logic +func WithSCMRateLimitRetry(config *SCMRateLimitConfig, apiCall func() (*http.Response, error)) (*http.Response, error) { + maxRetries := 3 + retryCount := 0 + + for { + resp, err := apiCall() + if err != nil { + return nil, err + } + + // Check if it's a rate limit error + if isRateLimitStatusCode(resp.StatusCode, config) { + rateLimitErr := ParseRateLimitHeaders(resp.Header, config) + wait := config.DefaultWaitTime + if rateLimitErr != nil { + wait = rateLimitErr.RetryAfter() + } + if retryCount >= maxRetries { + return nil, errors.Errorf("%s API rate limit exceeded after %d retries", config.Provider, maxRetries) + } + log.Printf("%s API rate limit exceeded (status %d). Waiting %v until %v before retrying... (attempt %d/%d)", + config.Provider, resp.StatusCode, wait, time.Now().Add(wait), retryCount+1, maxRetries) + time.Sleep(wait) + // Reset Authorization header before retry + if resp.Request != nil { + resetAuthorizationHeader(resp.Request) + } + retryCount++ + continue + } + return resp, err + } +} + +// ParseRateLimitHeaders extracts rate limit information from HTTP response headers +func ParseRateLimitHeaders(headers map[string][]string, config *SCMRateLimitConfig) *SCMRateLimitError { + resetHeader := getHeaderValue(headers, config.ResetHeaderName) + if resetHeader == "" { + return nil + } + + resetTime, err := strconv.ParseInt(resetHeader, 10, 64) + if err != nil { + return nil + } + + return &SCMRateLimitError{ + Provider: config.Provider, + ResetTime: resetTime, + } +} + +// getHeaderValue retrieves a header value in a case-insensitive manner +func getHeaderValue(headers map[string][]string, headerName string) string { + for name, values := range headers { + if strings.EqualFold(name, headerName) && len(values) > 0 { + return values[0] + } + } + return "" +} + +// isRateLimitStatusCode checks if the status code indicates a rate limit error +func isRateLimitStatusCode(statusCode int, config *SCMRateLimitConfig) bool { + for _, code := range config.RateLimitStatusCodes { + if statusCode == code { + return true + } + } + return false +} + +// resetAuthorizationHeader removes the Authorization header from the request +func resetAuthorizationHeader(req *http.Request) { + req.Header.Del("Authorization") +} diff --git a/test/integration/rate-limit_test.go b/test/integration/rate-limit_test.go new file mode 100644 index 000000000..05a0efcbb --- /dev/null +++ b/test/integration/rate-limit_test.go @@ -0,0 +1,78 @@ +package integration + +import ( + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/stretchr/testify/assert" +) + +func mockAPI(repeatCode, repeatCount int, headerName, headerValue string) func() (*http.Response, error) { + attempt := 0 + return func() (*http.Response, error) { + rec := httptest.NewRecorder() + if attempt < repeatCount { + rec.Code = repeatCode + if headerName != "" { + rec.Header().Set(headerName, headerValue) + } + } else { + rec.Code = http.StatusOK + } + attempt++ + resp := rec.Result() + resp.Body = io.NopCloser(strings.NewReader("")) + return resp, nil + } +} + +func runRateLimitTest(t *testing.T, config *wrappers.SCMRateLimitConfig, repeatCode, repeatCount int, headerName string) { + reset := strconv.FormatInt(time.Now().Unix(), 10) // simulate immediate retry + + //nolint:bodyclose // safe in test, body closed later + api := mockAPI(repeatCode, repeatCount, headerName, reset) + + start := time.Now() + resp, err := wrappers.WithSCMRateLimitRetry(config, api) + if resp != nil { + defer resp.Body.Close() + } + + assert := assert.New(t) + assert.NoError(err) + assert.NotNil(resp) + assert.Equal(http.StatusOK, resp.StatusCode) + + elapsed := time.Since(start) + assert.GreaterOrEqual(elapsed, config.DefaultWaitTime) +} + +func TestGitHubRateLimit_SuccessAfterRetryOne(t *testing.T) { + runRateLimitTest(t, wrappers.GitHubRateLimitConfig, 429, 1, "X-RateLimit-Reset") +} + +func TestGitHubRateLimit_SuccessAfterRetryTwo(t *testing.T) { + runRateLimitTest(t, wrappers.GitHubRateLimitConfig, 429, 2, "X-RateLimit-Reset") +} + +func TestGitHubRateLimit_SuccessAfterRetryThree(t *testing.T) { + runRateLimitTest(t, wrappers.GitHubRateLimitConfig, 403, 3, "X-RateLimit-Reset") +} + +func TestGitLabRateLimit_SuccessAfterRetryOne(t *testing.T) { + runRateLimitTest(t, wrappers.GitLabRateLimitConfig, 429, 1, "RateLimit-Reset") +} + +func TestBitBucketRateLimit_SuccessAfterRetryOne(t *testing.T) { + runRateLimitTest(t, wrappers.BitbucketRateLimitConfig, 429, 1, "X-RateLimit-Reset") +} + +func TestAzureRateLimit_SuccessAfterRetryOne(t *testing.T) { + runRateLimitTest(t, wrappers.AzureRateLimitConfig, 429, 1, "X-Ratelimit-Reset") +}