Skip to content

Commit dbdbb82

Browse files
authored
aws: Add singleflight support to SafeCredentialsProvider (#503)
1 parent 7e732f1 commit dbdbb82

File tree

14 files changed

+400
-38
lines changed

14 files changed

+400
-38
lines changed

CHANGELOG_PENDING.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ SDK Features
3131

3232
SDK Enhancements
3333
---
34+
* `aws`: Add grouping of concurrent refresh of credentials ([#503](https://github.com/aws/aws-sdk-go-v2/pull/503)
35+
* Concurrent calls to `Retrieve` are now grouped in order to prevent numerous synchronous calls to refresh the credentials. Replacing the mutex with a singleflight reduces the overall amount of time request signatures need to wait while retrieving credentials. This is improvement becomes pronounced when many requests are made concurrently.
3436
* `service/s3/s3manager`: Improve memory allocation behavior by replacing sync.Pool with custom pool implementation
3537
* Improves memory allocations that occur when the provided `io.Reader` to upload does not satisfy both the `io.ReaderAt` and `io.ReadSeeker` interfaces.
3638

Makefile

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMS
77
LINTIGNOREDEPS='vendor/.+\.go'
88
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
99
LINTIGNOREENDPOINTS='aws/endpoints/defaults.go:.+(method|const) .+ should be '
10+
LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type'
1011
UNIT_TEST_TAGS="example codegen awsinclude"
1112
ALL_TAGS="example codegen awsinclude integration perftest sdktool"
1213

@@ -145,7 +146,16 @@ verify: lint vet sdkv1check
145146
lint:
146147
@echo "go lint SDK and vendor packages"
147148
@lint=`golint ./...`; \
148-
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS}`; \
149+
dolint=`echo "$$lint" | grep -E -v \
150+
-e ${LINTIGNOREDOC} \
151+
-e ${LINTIGNORECONST} \
152+
-e ${LINTIGNORESTUTTER} \
153+
-e ${LINTIGNOREINFLECT} \
154+
-e ${LINTIGNOREDEPS} \
155+
-e ${LINTIGNOREINFLECTS3UPLOAD} \
156+
-e ${LINTIGNOREPKGCOMMENT} \
157+
-e ${LINTIGNOREENDPOINTS} \
158+
-e ${LINTIGNORESINGLEFIGHT}`; \
149159
echo "$$dolint"; \
150160
if [ "$$dolint" != "" ]; then exit 1; fi
151161

aws/chain_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ func NewChainProvider(providers []CredentialsProvider) *ChainProvider {
6161
//
6262
// If a provider is found it will be cached and any calls to IsExpired()
6363
// will return the expired state of the cached provider.
64-
func (c *ChainProvider) retrieveFn(ctx context.Context) (Credentials, error) {
64+
func (c *ChainProvider) retrieveFn() (Credentials, error) {
6565
var errs []error
6666
for _, p := range c.Providers {
67-
creds, err := p.Retrieve(ctx)
67+
creds, err := p.Retrieve(context.Background())
6868
if err == nil {
6969
return creds, nil
7070
}

aws/credentials.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package aws
33
import (
44
"context"
55
"math"
6-
"sync"
76
"sync/atomic"
87
"time"
98

9+
"github.com/aws/aws-sdk-go-v2/aws/awserr"
1010
"github.com/aws/aws-sdk-go-v2/internal/sdk"
11+
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
1112
)
1213

1314
// NeverExpire is the time identifier used when a credential provider's
@@ -83,10 +84,10 @@ type CredentialsProvider interface {
8384
// SafeCredentialsProvider provides caching and concurrency safe credentials
8485
// retrieval via the RetrieveFn.
8586
type SafeCredentialsProvider struct {
86-
RetrieveFn func(ctx context.Context) (Credentials, error)
87+
RetrieveFn func() (Credentials, error)
8788

8889
creds atomic.Value
89-
m sync.Mutex
90+
sf singleflight.Group
9091
}
9192

9293
// Retrieve returns the credentials. If the credentials have already been
@@ -99,21 +100,27 @@ func (p *SafeCredentialsProvider) Retrieve(ctx context.Context) (Credentials, er
99100
return *creds, nil
100101
}
101102

102-
p.m.Lock()
103-
defer p.m.Unlock()
103+
resCh := p.sf.DoChan("", p.singleRetrieve)
104+
select {
105+
case res := <-resCh:
106+
return res.Val.(Credentials), res.Err
107+
case <-ctx.Done():
108+
return Credentials{}, awserr.New("RequestCanceled",
109+
"request context canceled", ctx.Err())
110+
}
111+
}
104112

105-
// Make sure another goroutine didn't already update the credentials.
113+
func (p *SafeCredentialsProvider) singleRetrieve() (interface{}, error) {
106114
if creds := p.getCreds(); creds != nil {
107115
return *creds, nil
108116
}
109117

110-
creds, err := p.RetrieveFn(ctx)
111-
if err != nil {
112-
return Credentials{}, err
118+
creds, err := p.RetrieveFn()
119+
if err == nil {
120+
p.creds.Store(&creds)
113121
}
114-
p.creds.Store(&creds)
115122

116-
return creds, nil
123+
return creds, err
117124
}
118125

119126
func (p *SafeCredentialsProvider) getCreds() *Credentials {

aws/credentials_bench_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
)
1111

1212
func BenchmarkSafeCredentialsProvider_Retrieve(b *testing.B) {
13-
retrieveFn := func(ctx context.Context) (Credentials, error) {
13+
retrieveFn := func() (Credentials, error) {
1414
return Credentials{
1515
AccessKeyID: "key",
1616
SecretAccessKey: "secret",
@@ -45,7 +45,7 @@ func BenchmarkSafeCredentialsProvider_Retrieve(b *testing.B) {
4545
}
4646

4747
func BenchmarkSafeCredentialsProvider_Retrieve_Invalidate(b *testing.B) {
48-
retrieveFn := func(ctx context.Context) (Credentials, error) {
48+
retrieveFn := func() (Credentials, error) {
4949
time.Sleep(time.Millisecond)
5050
return Credentials{
5151
AccessKeyID: "key",

aws/credentials_test.go

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"math/rand"
77
"sync"
8+
"sync/atomic"
89
"testing"
910
"time"
1011

@@ -42,7 +43,7 @@ func TestSafeCredentialsProvider_Cache(t *testing.T) {
4243

4344
var called bool
4445
p := &SafeCredentialsProvider{
45-
RetrieveFn: func(ctx context.Context) (Credentials, error) {
46+
RetrieveFn: func() (Credentials, error) {
4647
if called {
4748
t.Fatalf("expect RetrieveFn to only be called once")
4849
}
@@ -108,7 +109,7 @@ func TestSafeCredentialsProvider_Expires(t *testing.T) {
108109
for _, c := range cases {
109110
var called int
110111
p := &SafeCredentialsProvider{
111-
RetrieveFn: func(ctx context.Context) (Credentials, error) {
112+
RetrieveFn: func() (Credentials, error) {
112113
called++
113114
return c.Creds(), nil
114115
},
@@ -132,7 +133,7 @@ func TestSafeCredentialsProvider_Expires(t *testing.T) {
132133

133134
func TestSafeCredentialsProvider_Error(t *testing.T) {
134135
p := &SafeCredentialsProvider{
135-
RetrieveFn: func(ctx context.Context) (Credentials, error) {
136+
RetrieveFn: func() (Credentials, error) {
136137
return Credentials{}, fmt.Errorf("failed")
137138
},
138139
}
@@ -156,7 +157,7 @@ func TestSafeCredentialsProvider_Race(t *testing.T) {
156157
}
157158
var called bool
158159
p := &SafeCredentialsProvider{
159-
RetrieveFn: func(ctx context.Context) (Credentials, error) {
160+
RetrieveFn: func() (Credentials, error) {
160161
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
161162
if called {
162163
t.Fatalf("expect RetrieveFn only called once")
@@ -186,3 +187,41 @@ func TestSafeCredentialsProvider_Race(t *testing.T) {
186187

187188
wg.Wait()
188189
}
190+
191+
type stubSafeProviderConcurrent struct {
192+
SafeCredentialsProvider
193+
called uint32
194+
done chan struct{}
195+
}
196+
197+
func TestSafeProviderRetrieveConcurrent(t *testing.T) {
198+
stub := &stubSafeProviderConcurrent{
199+
done: make(chan struct{}),
200+
}
201+
202+
stub.RetrieveFn = func() (Credentials, error) {
203+
atomic.AddUint32(&stub.called, 1)
204+
<-stub.done
205+
return Credentials{
206+
AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
207+
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
208+
}, nil
209+
}
210+
211+
done := make(chan struct{})
212+
for i := 0; i < 2; i++ {
213+
go func() {
214+
stub.Retrieve(context.Background())
215+
done <- struct{}{}
216+
}()
217+
}
218+
219+
// Validates that a single call to Retrieve is shared between two calls to Get
220+
stub.done <- struct{}{}
221+
<-done
222+
<-done
223+
224+
if e, a := uint32(1), atomic.LoadUint32(&stub.called); e != a {
225+
t.Errorf("expected %v, got %v", e, a)
226+
}
227+
}

aws/ec2rolecreds/provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ func New(client *ec2metadata.Client, options ...func(*ProviderOptions)) *Provide
6868
// Retrieve retrieves credentials from the EC2 service.
6969
// Error will be returned if the request fails, or unable to extract
7070
// the desired credentials.
71-
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
72-
credsList, err := requestCredList(ctx, p.client)
71+
func (p *Provider) retrieveFn() (aws.Credentials, error) {
72+
credsList, err := requestCredList(context.Background(), p.client)
7373
if err != nil {
7474
return aws.Credentials{}, err
7575
}
@@ -80,7 +80,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
8080
}
8181
credsName := credsList[0]
8282

83-
roleCreds, err := requestCred(ctx, p.client, credsName)
83+
roleCreds, err := requestCred(context.Background(), p.client, credsName)
8484
if err != nil {
8585
return aws.Credentials{}, err
8686
}

aws/endpointcreds/provider.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
package endpointcreds
3131

3232
import (
33-
"context"
3433
"encoding/json"
3534
"time"
3635

@@ -99,8 +98,8 @@ func New(cfg aws.Config, options ...func(*ProviderOptions)) *Provider {
9998

10099
// Retrieve will attempt to request the credentials from the endpoint the Provider
101100
// was configured for. And error will be returned if the retrieval fails.
102-
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
103-
resp, err := p.getCredentials(ctx)
101+
func (p *Provider) retrieveFn() (aws.Credentials, error) {
102+
resp, err := p.getCredentials()
104103
if err != nil {
105104
return aws.Credentials{},
106105
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
@@ -133,15 +132,14 @@ type errorOutput struct {
133132
Message string `json:"message"`
134133
}
135134

136-
func (p *Provider) getCredentials(ctx context.Context) (*getCredentialsOutput, error) {
135+
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
137136
op := &aws.Operation{
138137
Name: "GetCredentials",
139138
HTTPMethod: "GET",
140139
}
141140

142141
out := &getCredentialsOutput{}
143142
req := p.client.NewRequest(op, nil, out)
144-
req.SetContext(ctx)
145143
req.HTTPRequest.Header.Set("Accept", "application/json")
146144
if authToken := p.options.AuthorizationToken; len(authToken) != 0 {
147145
req.HTTPRequest.Header.Set("Authorization", authToken)

aws/processcreds/provider.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ type credentialProcessResponse struct {
200200
}
201201

202202
// retrieveFn executes the 'credential_process' and returns the credentials.
203-
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
204-
out, err := p.executeCredentialProcess(ctx)
203+
func (p *Provider) retrieveFn() (aws.Credentials, error) {
204+
out, err := p.executeCredentialProcess()
205205
if err != nil {
206206
return aws.Credentials{Source: ProviderName}, err
207207
}
@@ -253,7 +253,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
253253
}
254254

255255
// prepareCommand prepares the command to be executed.
256-
func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context.CancelFunc, error) {
256+
func (p *Provider) prepareCommand() (context.Context, context.CancelFunc, error) {
257257

258258
var cmdArgs []string
259259
if runtime.GOOS == "windows" {
@@ -278,7 +278,7 @@ func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context
278278
}
279279
}
280280

281-
timeoutCtx, cancelFunc := context.WithTimeout(ctx, p.options.Timeout)
281+
timeoutCtx, cancelFunc := context.WithTimeout(context.Background(), p.options.Timeout)
282282

283283
cmdArgs = append(cmdArgs, p.originalCommand...)
284284
p.command = exec.CommandContext(timeoutCtx, cmdArgs[0], cmdArgs[1:]...)
@@ -289,8 +289,8 @@ func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context
289289

290290
// executeCredentialProcess starts the credential process on the OS and
291291
// returns the results or an error.
292-
func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
293-
ctx, cancelFunc, err := p.prepareCommand(ctx)
292+
func (p *Provider) executeCredentialProcess() ([]byte, error) {
293+
ctx, cancelFunc, err := p.prepareCommand()
294294
if err != nil {
295295
return nil, err
296296
}

aws/stscreds/provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func NewAssumeRoleProvider(client AssumeRoler, roleARN string, options ...func(*
211211
}
212212

213213
// Retrieve generates a new set of temporary credentials using STS.
214-
func (p *AssumeRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
214+
func (p *AssumeRoleProvider) retrieveFn() (aws.Credentials, error) {
215215
// Apply defaults where parameters are not set.
216216
if len(p.options.RoleSessionName) == 0 {
217217
// Try to work out a role name that will hopefully end up unique.
@@ -246,7 +246,7 @@ func (p *AssumeRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, e
246246
}
247247

248248
req := p.client.AssumeRoleRequest(input)
249-
resp, err := req.Send(ctx)
249+
resp, err := req.Send(context.Background())
250250
if err != nil {
251251
return aws.Credentials{Source: ProviderName}, err
252252
}

0 commit comments

Comments
 (0)