Skip to content

Commit ca982f9

Browse files
authored
fix: Clean up usage retry logic (#1950)
I may have used the term "clean up" loosely.
1 parent 78df77d commit ca982f9

File tree

2 files changed

+274
-197
lines changed

2 files changed

+274
-197
lines changed

premium/usage.go

Lines changed: 80 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"github.com/cloudquery/cloudquery-api-go/config"
2222
"github.com/cloudquery/plugin-sdk/v4/plugin"
2323
"github.com/google/uuid"
24+
"github.com/hashicorp/go-retryablehttp"
2425
"github.com/rs/zerolog"
25-
"github.com/rs/zerolog/log"
2626
)
2727

2828
const (
@@ -32,6 +32,9 @@ const (
3232
defaultMaxWaitTime = 60 * time.Second
3333
defaultMinTimeBetweenFlushes = 10 * time.Second
3434
defaultMaxTimeBetweenFlushes = 30 * time.Second
35+
36+
marketplaceDuplicateWaitTime = 1 * time.Second
37+
marketplaceMinRetries = 20
3538
)
3639

3740
const (
@@ -109,7 +112,9 @@ func WithMinTimeBetweenFlushes(minTimeBetweenFlushes time.Duration) UsageClientO
109112
// WithMaxRetries sets the maximum number of retries to update the usage in case of an API error
110113
func WithMaxRetries(maxRetries int) UsageClientOptions {
111114
return func(updater *BatchUpdater) {
112-
updater.maxRetries = maxRetries
115+
if maxRetries > 0 {
116+
updater.maxRetries = maxRetries
117+
}
113118
}
114119
}
115120

@@ -198,6 +203,9 @@ type BatchUpdater struct {
198203
isClosed bool
199204
dataOnClose bool
200205
usageIncreaseMethod int
206+
207+
// Testing
208+
timeFunc func() time.Time
201209
}
202210

203211
func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, error) {
@@ -216,6 +224,7 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e
216224
triggerUpdate: make(chan struct{}),
217225
done: make(chan struct{}),
218226
closeError: make(chan error),
227+
timeFunc: time.Now,
219228

220229
tables: map[string]uint32{},
221230
}
@@ -246,18 +255,26 @@ func NewUsageClient(meta plugin.Meta, ops ...UsageClientOptions) (UsageClient, e
246255

247256
// Create a default api client if none was provided
248257
if u.apiClient == nil {
249-
ac, err := cqapi.NewClientWithResponses(u.url, cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
250-
token, err := u.tokenClient.GetToken()
251-
if err != nil {
252-
return fmt.Errorf("failed to get token: %w", err)
253-
}
254-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
255-
return nil
256-
}))
258+
retryClient := retryablehttp.NewClient()
259+
retryClient.Logger = nil
260+
retryClient.RetryMax = u.maxRetries
261+
retryClient.RetryWaitMax = u.maxWaitTime
262+
263+
var err error
264+
u.apiClient, err = cqapi.NewClientWithResponses(u.url,
265+
cqapi.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
266+
token, err := u.tokenClient.GetToken()
267+
if err != nil {
268+
return fmt.Errorf("failed to get token: %w", err)
269+
}
270+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
271+
return nil
272+
}),
273+
cqapi.WithHTTPClient(retryClient.StandardClient()),
274+
)
257275
if err != nil {
258276
return nil, fmt.Errorf("failed to create api client: %w", err)
259277
}
260-
u.apiClient = ac
261278
}
262279

263280
// Set team name from configuration if not provided
@@ -289,11 +306,12 @@ func (u *BatchUpdater) setupAWSMarketplace() error {
289306
u.batchLimit = 1000000000
290307

291308
u.minTimeBetweenFlushes = 1 * time.Minute
309+
u.maxRetries = max(u.maxRetries, marketplaceMinRetries)
292310
u.backgroundUpdater()
293311

294312
_, err = u.awsMarketplaceClient.MeterUsage(ctx, &marketplacemetering.MeterUsageInput{
295313
ProductCode: aws.String(awsMarketplaceProductCode()),
296-
Timestamp: aws.Time(time.Now()),
314+
Timestamp: aws.Time(u.timeFunc()),
297315
UsageDimension: aws.String("rows"),
298316
UsageQuantity: aws.Int32(int32(0)),
299317
DryRun: aws.Bool(true),
@@ -486,13 +504,13 @@ func (u *BatchUpdater) backgroundUpdater() {
486504
}
487505
// If we are using AWS Marketplace, we need to round down to the nearest 1000
488506
// Only on the last update, will we round up to the nearest 1000
489-
// This will allow us to not over charge the customer by rounding on each batch
507+
// This will allow us to not overcharge the customer by rounding on each batch
490508
if u.awsMarketplaceClient != nil {
491509
totals = roundDown(totals, 1000)
492510
}
493511

494512
if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil {
495-
log.Warn().Err(err).Msg("failed to update usage")
513+
u.logger.Warn().Err(err).Msg("failed to update usage")
496514
continue
497515
}
498516
u.subtractTableUsage(tables, totals)
@@ -510,12 +528,12 @@ func (u *BatchUpdater) backgroundUpdater() {
510528
}
511529
// If we are using AWS Marketplace, we need to round down to the nearest 1000
512530
// Only on the last update, will we round up to the nearest 1000
513-
// This will allow us to not over charge the customer by rounding on each batch
531+
// This will allow us to not overcharge the customer by rounding on each batch
514532
if u.awsMarketplaceClient != nil {
515533
totals = roundDown(totals, 1000)
516534
}
517535
if err := u.updateUsageWithRetryAndBackoff(ctx, totals, tables); err != nil {
518-
log.Warn().Err(err).Msg("failed to update usage")
536+
u.logger.Warn().Err(err).Msg("failed to update usage")
519537
continue
520538
}
521539
u.subtractTableUsage(tables, totals)
@@ -573,60 +591,70 @@ func (u *BatchUpdater) reportUsageToAWSMarketplace(ctx context.Context, rows uin
573591
// Each product is given a unique product code when it is listed in AWS Marketplace
574592
// in the future we can have multiple product codes for container or AMI based listings
575593
ProductCode: aws.String(awsMarketplaceProductCode()),
576-
Timestamp: aws.Time(time.Now()),
594+
Timestamp: aws.Time(u.timeFunc()),
577595
UsageDimension: aws.String("rows"),
578596
UsageAllocations: usage,
579597
UsageQuantity: aws.Int32(int32(rows)),
580598
})
581599
if err != nil {
582-
return fmt.Errorf("failed to update usage with : %w", err)
600+
return fmt.Errorf("failed to update usage: %w", err)
583601
}
584602
return nil
585603
}
586604

587-
func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error {
605+
func (u *BatchUpdater) updateMarketplaceUsage(ctx context.Context, rows uint32) error {
606+
var lastErr error
588607
for retry := 0; retry < u.maxRetries; retry++ {
589-
u.logger.Debug().Str("url", u.url).Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage")
590-
queryStartTime := time.Now()
608+
u.logger.Debug().Int("try", retry).Int("max_retries", u.maxRetries).Uint32("rows", rows).Msg("updating usage")
591609

592-
// If the AWS Marketplace client is set, use it to track usage
593-
if u.awsMarketplaceClient != nil {
594-
return u.reportUsageToAWSMarketplace(ctx, rows)
595-
}
596-
payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{
597-
RequestId: uuid.New(),
598-
PluginTeam: u.pluginMeta.Team,
599-
PluginKind: u.pluginMeta.Kind,
600-
PluginName: u.pluginMeta.Name,
601-
Rows: int(rows),
610+
lastErr = u.reportUsageToAWSMarketplace(ctx, rows)
611+
if lastErr == nil {
612+
u.logger.Debug().Int("try", retry).Uint32("rows", rows).Msg("usage updated")
613+
return nil
602614
}
603615

604-
if len(tables) > 0 {
605-
payload.Tables = &tables
616+
var de *types.DuplicateRequestException
617+
if !errors.As(lastErr, &de) {
618+
return fmt.Errorf("failed to update usage: %w", lastErr)
606619
}
620+
u.logger.Debug().Err(lastErr).Int("try", retry).Uint32("rows", rows).Msg("usage update failed due to duplicate request")
607621

608-
resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload)
609-
if err != nil {
610-
return fmt.Errorf("failed to update usage: %w", err)
611-
}
612-
if resp.StatusCode() >= 200 && resp.StatusCode() < 300 {
613-
u.logger.Debug().Str("url", u.url).Int("try", retry).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated")
614-
u.lastUpdateTime = time.Now().UTC()
615-
if resp.HTTPResponse != nil {
616-
u.updateConfigurationFromHeaders(resp.HTTPResponse.Header)
617-
}
618-
return nil
619-
}
622+
jitter := time.Duration(rand.Intn(1000)) * time.Millisecond
623+
time.Sleep(marketplaceDuplicateWaitTime + jitter)
624+
}
625+
return fmt.Errorf("failed to update usage: max retries exceeded: %w", lastErr)
626+
}
620627

621-
retryDuration, err := u.calculateRetryDuration(resp.StatusCode(), resp.HTTPResponse.Header, queryStartTime, retry)
622-
if err != nil {
623-
return fmt.Errorf("failed to calculate retry duration: %w", err)
624-
}
625-
if retryDuration > 0 {
626-
time.Sleep(retryDuration)
628+
func (u *BatchUpdater) updateUsageWithRetryAndBackoff(ctx context.Context, rows uint32, tables []cqapi.UsageIncreaseTablesInner) error {
629+
// If the AWS Marketplace client is set, use it to track usage
630+
if u.awsMarketplaceClient != nil {
631+
return u.updateMarketplaceUsage(ctx, rows)
632+
}
633+
634+
u.logger.Debug().Str("url", u.url).Uint32("rows", rows).Msg("updating usage")
635+
payload := cqapi.IncreaseTeamPluginUsageJSONRequestBody{
636+
RequestId: uuid.New(),
637+
PluginTeam: u.pluginMeta.Team,
638+
PluginKind: u.pluginMeta.Kind,
639+
PluginName: u.pluginMeta.Name,
640+
Rows: int(rows),
641+
}
642+
643+
if len(tables) > 0 {
644+
payload.Tables = &tables
645+
}
646+
647+
resp, err := u.apiClient.IncreaseTeamPluginUsageWithResponse(ctx, u.teamName, payload)
648+
if err == nil && resp.StatusCode() >= 200 && resp.StatusCode() < 300 {
649+
u.logger.Debug().Str("url", u.url).Int("status_code", resp.StatusCode()).Uint32("rows", rows).Msg("usage updated")
650+
u.lastUpdateTime = u.timeFunc().UTC()
651+
if resp.HTTPResponse != nil {
652+
u.updateConfigurationFromHeaders(resp.HTTPResponse.Header)
627653
}
654+
return nil
628655
}
629-
return fmt.Errorf("failed to update usage: max retries exceeded")
656+
657+
return fmt.Errorf("failed to update usage: %w", err)
630658
}
631659

632660
// updateConfigurationFromHeaders updates the configuration based on the headers returned by the API
@@ -663,33 +691,6 @@ func (u *BatchUpdater) updateConfigurationFromHeaders(header http.Header) {
663691
}
664692
}
665693

666-
// calculateRetryDuration calculates the duration to sleep relative to the query start time before retrying an update
667-
func (u *BatchUpdater) calculateRetryDuration(statusCode int, headers http.Header, queryStartTime time.Time, retry int) (time.Duration, error) {
668-
if !retryableStatusCode(statusCode) {
669-
return 0, fmt.Errorf("non-retryable status code: %d", statusCode)
670-
}
671-
672-
// Check if we have a retry-after header
673-
retryAfter := headers.Get("Retry-After")
674-
if retryAfter != "" {
675-
retryDelay, err := time.ParseDuration(retryAfter + "s")
676-
if err != nil {
677-
return 0, fmt.Errorf("failed to parse retry-after header: %w", err)
678-
}
679-
return retryDelay, nil
680-
}
681-
682-
// Calculate exponential backoff
683-
baseRetry := min(time.Duration(1<<retry)*time.Second, u.maxWaitTime)
684-
jitter := time.Duration(rand.Intn(1000)) * time.Millisecond
685-
retryDelay := baseRetry + jitter
686-
return retryDelay - time.Since(queryStartTime), nil
687-
}
688-
689-
func retryableStatusCode(statusCode int) bool {
690-
return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable
691-
}
692-
693694
func (u *BatchUpdater) getTeamNameByTokenType(tokenType auth.TokenType) (string, error) {
694695
switch tokenType {
695696
case auth.BearerToken:

0 commit comments

Comments
 (0)