Skip to content

Commit 0dd2599

Browse files
authored
Use underlying error when checking for context errors (Azure#20404)
* Use underlying error when checking for context errors There is a race between error propagation and closing a context's deadline channel. The result is that while the API that takes a context returns the correct error, context.Err() might return nil. See golang/go#31863 for more info. * further refinements * revert ordering of error check * add changelog entry
1 parent c341474 commit 0dd2599

8 files changed

+40
-19
lines changed

sdk/azidentity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
### Breaking Changes
1010

1111
### Bugs Fixed
12+
* Fixed an issue in `DefaultAzureCredential` that could cause the managed identity endpoint check to fail in rare circumstances.
1213

1314
### Other Changes
1415

sdk/azidentity/chained_token_credential.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
117117
err = newCredentialUnavailableError(c.name, msg)
118118
} else {
119119
res := getResponseFromError(err)
120-
err = newAuthenticationFailedError(c.name, msg, res)
120+
err = newAuthenticationFailedError(c.name, msg, res, err)
121121
}
122122
}
123123
return token, err

sdk/azidentity/chained_token_credential_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) {
113113

114114
func TestChainedTokenCredential_GetTokenFail(t *testing.T) {
115115
c := NewFakeCredential()
116-
c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil))
116+
c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil, nil))
117117
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c}, nil)
118118
if err != nil {
119119
t.Fatal(err)
@@ -158,7 +158,7 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenAuthenticationFailed(
158158
c2 := NewFakeCredential()
159159
c2.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("unavailableCredential2", "Unavailable expected error"))
160160
c3 := NewFakeCredential()
161-
c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil))
161+
c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil, nil))
162162
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c1, c2, c3}, nil)
163163
if err != nil {
164164
t.Fatal(err)
@@ -263,7 +263,7 @@ func TestChainedTokenCredential_Race(t *testing.T) {
263263
successFake := NewFakeCredential()
264264
successFake.SetResponse(azcore.AccessToken{Token: "*", ExpiresOn: time.Now().Add(time.Hour)}, nil)
265265
authFailFake := NewFakeCredential()
266-
authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil))
266+
authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil, nil))
267267
unavailableFake := NewFakeCredential()
268268
unavailableFake.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("", ""))
269269
tro := policy.TokenRequestOptions{Scopes: []string{liveTestScope}}

sdk/azidentity/default_azure_credential.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func (w *timeoutWrapper) GetToken(ctx context.Context, opts policy.TokenRequestO
182182
c, cancel := context.WithTimeout(ctx, w.timeout)
183183
defer cancel()
184184
tk, err = w.mic.GetToken(c, opts)
185-
if ce := c.Err(); errors.Is(ce, context.DeadlineExceeded) {
185+
if isAuthFailedDueToContext(err) {
186186
err = newCredentialUnavailableError(credNameManagedIdentity, "managed identity timed out")
187187
} else {
188188
// some managed identity implementation is available, so don't apply the timeout to future calls
@@ -193,3 +193,15 @@ func (w *timeoutWrapper) GetToken(ctx context.Context, opts policy.TokenRequestO
193193
}
194194
return tk, err
195195
}
196+
197+
// unwraps nested AuthenticationFailedErrors to get the root error
198+
func isAuthFailedDueToContext(err error) bool {
199+
for {
200+
var authFailedErr *AuthenticationFailedError
201+
if !errors.As(err, &authFailedErr) {
202+
break
203+
}
204+
err = authFailedErr.err
205+
}
206+
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
207+
}

sdk/azidentity/default_azure_credential_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,14 @@ type delayPolicy struct {
171171
}
172172

173173
func (p *delayPolicy) Do(req *policy.Request) (resp *http.Response, err error) {
174-
time.Sleep(p.delay)
174+
if p.delay > 0 {
175+
select {
176+
case <-req.Raw().Context().Done():
177+
return nil, req.Raw().Context().Err()
178+
case <-time.After(p.delay):
179+
// delay has elapsed, continue on
180+
}
181+
}
175182
return req.Next()
176183
}
177184

@@ -180,7 +187,7 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) {
180187
defer close()
181188
srv.SetResponse(mock.WithBody(accessTokenRespSuccess))
182189

183-
timeout := 5 * time.Millisecond
190+
timeout := 100 * time.Millisecond
184191
dp := delayPolicy{2 * timeout}
185192
mic, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
186193
ClientOptions: policy.ClientOptions{
@@ -201,7 +208,7 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) {
201208
// expecting credentialUnavailableError because delay exceeds the wrapper's timeout
202209
_, err = chain.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
203210
if _, ok := err.(*credentialUnavailableError); !ok {
204-
t.Fatalf("expected credentialUnavailableError, got %v", err)
211+
t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err)
205212
}
206213
}
207214

sdk/azidentity/errors.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,16 @@ type AuthenticationFailedError struct {
3939

4040
credType string
4141
message string
42+
err error
4243
}
4344

44-
func newAuthenticationFailedError(credType string, message string, resp *http.Response) error {
45-
return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp}
45+
func newAuthenticationFailedError(credType string, message string, resp *http.Response, err error) error {
46+
return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp, err: err}
4647
}
4748

4849
func newAuthenticationFailedErrorFromMSALError(credType string, err error) error {
4950
res := getResponseFromError(err)
50-
return newAuthenticationFailedError(credType, err.Error(), res)
51+
return newAuthenticationFailedError(credType, err.Error(), res, err)
5152
}
5253

5354
// Error implements the error interface. Note that the message contents are not contractual and can change over time.

sdk/azidentity/errors_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func TestAuthenticationFailedErrorInterface(t *testing.T) {
3131
Body: io.NopCloser(bytes.NewBufferString(resBodyString)),
3232
Request: req,
3333
}
34-
err = newAuthenticationFailedError(credNameAzureCLI, "error message", res)
34+
err = newAuthenticationFailedError(credNameAzureCLI, "error message", res, nil)
3535
if e, ok := err.(*AuthenticationFailedError); ok {
3636
if e.RawResponse == nil {
3737
t.Fatal("expected a non-nil RawResponse")

sdk/azidentity/managed_identity_client.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
168168

169169
resp, err := c.pipeline.Do(msg)
170170
if err != nil {
171-
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil)
171+
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err)
172172
}
173173

174174
if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
@@ -177,12 +177,12 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi
177177

178178
if c.msiType == msiTypeIMDS && resp.StatusCode == 400 {
179179
if id != nil {
180-
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp)
180+
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil)
181181
}
182182
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, "no default identity is assigned to this resource")
183183
}
184184

185-
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp)
185+
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp, nil)
186186
}
187187

188188
func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.AccessToken, error) {
@@ -210,10 +210,10 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac
210210
if expiresOn, err := strconv.Atoi(v); err == nil {
211211
return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil
212212
}
213-
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res)
213+
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res, nil)
214214
default:
215215
msg := fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)
216-
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res)
216+
return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res, nil)
217217
}
218218
}
219219

@@ -228,7 +228,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
228228
key, err := c.getAzureArcSecretKey(ctx, scopes)
229229
if err != nil {
230230
msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
231-
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil)
231+
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
232232
}
233233
return c.createAzureArcAuthRequest(ctx, id, scopes, key)
234234
case msiTypeServiceFabric:
@@ -322,7 +322,7 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
322322
// of the secret key file. Any other status code indicates an error in the request.
323323
if response.StatusCode != 401 {
324324
msg := fmt.Sprintf("expected a 401 response, received %d", response.StatusCode)
325-
return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response)
325+
return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response, nil)
326326
}
327327
header := response.Header.Get("WWW-Authenticate")
328328
if len(header) == 0 {

0 commit comments

Comments
 (0)