@@ -2,6 +2,7 @@ package aws
22
33import (
44 "context"
5+ "fmt"
56 "sync/atomic"
67 "time"
78
@@ -24,11 +25,13 @@ type CredentialsCacheOptions struct {
2425 // If ExpiryWindow is 0 or less it will be ignored.
2526 ExpiryWindow time.Duration
2627
27- // ExpiryWindowJitterFrac provides a mechanism for randomizing the expiration of credentials
28- // within the configured ExpiryWindow by a random percentage. Valid values are between 0.0 and 1.0.
28+ // ExpiryWindowJitterFrac provides a mechanism for randomizing the
29+ // expiration of credentials within the configured ExpiryWindow by a random
30+ // percentage. Valid values are between 0.0 and 1.0.
2931 //
30- // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac is 0.5 then credentials will be set to
31- // expire between 30 to 60 seconds prior to their actual expiration time.
32+ // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
33+ // is 0.5 then credentials will be set to expire between 30 to 60 seconds
34+ // prior to their actual expiration time.
3235 //
3336 // If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
3437 // If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
@@ -39,17 +42,29 @@ type CredentialsCacheOptions struct {
3942
4043// CredentialsCache provides caching and concurrency safe credentials retrieval
4144// via the provider's retrieve method.
45+ //
46+ // CredentialsCache will look for optional interfaces on the Provider to adjust
47+ // how the credential cache handles credentials caching.
48+ //
49+ // * HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
50+ // credential refresh failures. This could return an updated Credentials
51+ // value, or attempt another means of retrieving credentials.
52+ //
53+ // * AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
54+ // credentials Expires is modified. This could modify how the Credentials
55+ // Expires is adjusted based on the CredentialsCache ExpiryWindow option.
56+ // Such as providing a floor not to reduce the Expires below.
4257type CredentialsCache struct {
43- // provider is the CredentialProvider implementation to be wrapped by the CredentialCache.
4458 provider CredentialsProvider
4559
4660 options CredentialsCacheOptions
4761 creds atomic.Value
4862 sf singleflight.Group
4963}
5064
51- // NewCredentialsCache returns a CredentialsCache that wraps provider. Provider is expected to not be nil. A variadic
52- // list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for
65+ // NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
66+ // is expected to not be nil. A variadic list of one or more functions can be
67+ // provided to modify the CredentialsCache configuration. This allows for
5368// configuration of credential expiry window and jitter.
5469func NewCredentialsCache (provider CredentialsProvider , optFns ... func (options * CredentialsCacheOptions )) * CredentialsCache {
5570 options := CredentialsCacheOptions {}
@@ -81,8 +96,8 @@ func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *C
8196//
8297// Returns and error if the provider's retrieve method returns an error.
8398func (p * CredentialsCache ) Retrieve (ctx context.Context ) (Credentials , error ) {
84- if creds := p .getCreds (); creds != nil {
85- return * creds , nil
99+ if creds , ok := p .getCreds (); ok && ! creds . Expired () {
100+ return creds , nil
86101 }
87102
88103 resCh := p .sf .DoChan ("" , func () (interface {}, error ) {
@@ -97,43 +112,107 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
97112}
98113
99114func (p * CredentialsCache ) singleRetrieve (ctx context.Context ) (interface {}, error ) {
100- if creds := p .getCreds (); creds != nil {
101- return * creds , nil
115+ currCreds , ok := p .getCreds ()
116+ if ok && ! currCreds .Expired () {
117+ return currCreds , nil
118+ }
119+
120+ newCreds , err := p .provider .Retrieve (ctx )
121+ if err != nil {
122+ handleFailToRefresh := defaultHandleFailToRefresh
123+ if cs , ok := p .provider .(HandleFailRefreshCredentialsCacheStrategy ); ok {
124+ handleFailToRefresh = cs .HandleFailToRefresh
125+ }
126+ newCreds , err = handleFailToRefresh (ctx , currCreds , err )
127+ if err != nil {
128+ return Credentials {}, fmt .Errorf ("failed to refresh cached credentials, %w" , err )
129+ }
102130 }
103131
104- creds , err := p .provider . Retrieve ( ctx )
105- if err == nil {
106- if creds . CanExpire {
107- randFloat64 , err := sdkrand . CryptoRandFloat64 ()
108- if err != nil {
109- return Credentials {}, err
110- }
111- jitter := time . Duration ( randFloat64 * p . options . ExpiryWindowJitterFrac * float64 ( p . options . ExpiryWindow ))
112- creds . Expires = creds . Expires . Add ( - ( p . options . ExpiryWindow - jitter ) )
132+ if newCreds . CanExpire && p .options . ExpiryWindow > 0 {
133+ adjustExpiresBy := defaultAdjustExpiresBy
134+ if cs , ok := p . provider .( AdjustExpiresByCredentialsCacheStrategy ); ok {
135+ adjustExpiresBy = cs . AdjustExpiresBy
136+ }
137+
138+ randFloat64 , err := sdkrand . CryptoRandFloat64 ()
139+ if err != nil {
140+ return Credentials {}, fmt . Errorf ( "failed to get random provider, %w" , err )
113141 }
114142
115- p .creds .Store (& creds )
143+ var jitter time.Duration
144+ if p .options .ExpiryWindowJitterFrac > 0 {
145+ jitter = time .Duration (randFloat64 *
146+ p .options .ExpiryWindowJitterFrac * float64 (p .options .ExpiryWindow ))
147+ }
148+
149+ newCreds , err = adjustExpiresBy (newCreds , - (p .options .ExpiryWindow - jitter ))
150+ if err != nil {
151+ return Credentials {}, fmt .Errorf ("failed to adjust credentials expires, %w" , err )
152+ }
116153 }
117154
118- return creds , err
155+ p .creds .Store (& newCreds )
156+ return newCreds , nil
119157}
120158
121- func (p * CredentialsCache ) getCreds () * Credentials {
159+ // getCreds returns the currently stored credentials and true. Returning false
160+ // if no credentials were stored.
161+ func (p * CredentialsCache ) getCreds () (Credentials , bool ) {
122162 v := p .creds .Load ()
123163 if v == nil {
124- return nil
164+ return Credentials {}, false
125165 }
126166
127167 c := v .(* Credentials )
128- if c != nil && c . HasKeys () && ! c .Expired () {
129- return c
168+ if c == nil || ! c .HasKeys () {
169+ return Credentials {}, false
130170 }
131171
132- return nil
172+ return * c , true
133173}
134174
135175// Invalidate will invalidate the cached credentials. The next call to Retrieve
136176// will cause the provider's Retrieve method to be called.
137177func (p * CredentialsCache ) Invalidate () {
138178 p .creds .Store ((* Credentials )(nil ))
139179}
180+
181+ // HandleFailRefreshCredentialsCacheStrategy is an interface for
182+ // CredentialsCache to allow CredentialsProvider how failed to refresh
183+ // credentials is handled.
184+ type HandleFailRefreshCredentialsCacheStrategy interface {
185+ // Given the previously cached Credentials, if any, and refresh error, may
186+ // returns new or modified set of Credentials, or error.
187+ //
188+ // Credential caches may use default implementation if nil.
189+ HandleFailToRefresh (context.Context , Credentials , error ) (Credentials , error )
190+ }
191+
192+ // defaultHandleFailToRefresh returns the passed in error.
193+ func defaultHandleFailToRefresh (ctx context.Context , _ Credentials , err error ) (Credentials , error ) {
194+ return Credentials {}, err
195+ }
196+
197+ // AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
198+ // to allow CredentialsProvider to intercept adjustments to Credentials expiry
199+ // based on expectations and use cases of CredentialsProvider.
200+ //
201+ // Credential caches may use default implementation if nil.
202+ type AdjustExpiresByCredentialsCacheStrategy interface {
203+ // Given a Credentials as input, applying any mutations and
204+ // returning the potentially updated Credentials, or error.
205+ AdjustExpiresBy (Credentials , time.Duration ) (Credentials , error )
206+ }
207+
208+ // defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
209+ // and returns the updated credentials value. If Credentials value's CanExpire
210+ // is false, the passed in credentials are returned unchanged.
211+ func defaultAdjustExpiresBy (creds Credentials , dur time.Duration ) (Credentials , error ) {
212+ if ! creds .CanExpire {
213+ return creds , nil
214+ }
215+
216+ creds .Expires = creds .Expires .Add (dur )
217+ return creds , nil
218+ }
0 commit comments