|
13 | 13 | using Microsoft.Identity.Client; |
14 | 14 | using Microsoft.Identity.Client.Core; |
15 | 15 | using Microsoft.Identity.Client.Http.Retry; |
| 16 | +using Microsoft.Identity.Client.OAuth2; |
16 | 17 | using Microsoft.Identity.Test.Common; |
17 | 18 | using Microsoft.Identity.Test.Common.Core.Helpers; |
18 | 19 | using Microsoft.Identity.Test.Common.Core.Mocks; |
@@ -536,6 +537,94 @@ public async Task TestSendPostWithRetryOnTimeoutFailureAsync() |
536 | 537 | } |
537 | 538 | } |
538 | 539 |
|
| 540 | + [TestMethod] |
| 541 | + [DataRow(true)] |
| 542 | + [DataRow(false)] |
| 543 | + public async Task TestCorrelationIdWithRetryOnTimeoutFailureAsync(bool addCorrelationId) |
| 544 | + { |
| 545 | + using (var httpManager = new MockHttpManager()) |
| 546 | + { |
| 547 | + // Simulate permanent errors (to trigger the maximum number of retries) |
| 548 | + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries |
| 549 | + for (int i = 0; i < Num500Errors; i++) |
| 550 | + { |
| 551 | + httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Post); |
| 552 | + } |
| 553 | + |
| 554 | + Guid correlationId = Guid.NewGuid(); |
| 555 | + var headers = new Dictionary<string, string>(); |
| 556 | + |
| 557 | + if (addCorrelationId) |
| 558 | + { |
| 559 | + headers.Add(OAuth2Header.CorrelationId, correlationId.ToString()); |
| 560 | + } |
| 561 | + |
| 562 | + var exc = await AssertException.TaskThrowsAsync<MsalServiceException>(() => |
| 563 | + httpManager.SendRequestAsync( |
| 564 | + new Uri(TestConstants.AuthorityHomeTenant + "oauth2/token"), |
| 565 | + headers: headers, |
| 566 | + body: new FormUrlEncodedContent(new Dictionary<string, string>()), |
| 567 | + method: HttpMethod.Post, |
| 568 | + logger: Substitute.For<ILoggerAdapter>(), |
| 569 | + doNotThrow: false, |
| 570 | + mtlsCertificate: null, |
| 571 | + validateServerCert: null, |
| 572 | + cancellationToken: default, |
| 573 | + retryPolicy: _stsRetryPolicy)) |
| 574 | + .ConfigureAwait(false); |
| 575 | + |
| 576 | + Assert.AreEqual(MsalError.RequestTimeout, exc.ErrorCode); |
| 577 | + |
| 578 | + if (addCorrelationId) |
| 579 | + { |
| 580 | + Assert.AreEqual($"Request to the endpoint timed out. CorrelationId: {correlationId.ToString()}", exc.Message); |
| 581 | + Assert.AreEqual(correlationId.ToString(), exc.CorrelationId); |
| 582 | + } |
| 583 | + else |
| 584 | + { |
| 585 | + Assert.AreEqual("Request to the endpoint timed out.", exc.Message); |
| 586 | + } |
| 587 | + } |
| 588 | + } |
| 589 | + |
| 590 | + [TestMethod] |
| 591 | + public async Task TestWithCorrelationId_RetryOnTimeoutFailureAsync() |
| 592 | + { |
| 593 | + // Arrange |
| 594 | + using (var httpManager = new MockHttpManager()) |
| 595 | + { |
| 596 | + httpManager.AddInstanceDiscoveryMockHandler(); |
| 597 | + |
| 598 | + // Simulate permanent errors (to trigger the maximum number of retries) |
| 599 | + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries |
| 600 | + for (int i = 0; i < Num500Errors; i++) |
| 601 | + { |
| 602 | + httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Post); |
| 603 | + } |
| 604 | + Guid correlationId = Guid.NewGuid(); |
| 605 | + |
| 606 | + var app = ConfidentialClientApplicationBuilder |
| 607 | + .Create(TestConstants.ClientId) |
| 608 | + .WithAuthority(TestConstants.AuthorityTestTenant) |
| 609 | + .WithHttpManager(httpManager) |
| 610 | + .WithClientSecret(TestConstants.ClientSecret) |
| 611 | + .Build(); |
| 612 | + |
| 613 | + var userAssertion = new UserAssertion(TestConstants.DefaultAccessToken); |
| 614 | + |
| 615 | + // Act |
| 616 | + var exc = await AssertException.TaskThrowsAsync<MsalServiceException>(() => |
| 617 | + app.AcquireTokenForClient(TestConstants.s_scope) |
| 618 | + .WithCorrelationId(correlationId) |
| 619 | + .ExecuteAsync()) |
| 620 | + .ConfigureAwait(false); |
| 621 | + |
| 622 | + // Assert |
| 623 | + Assert.AreEqual($"Request to the endpoint timed out. CorrelationId: {correlationId.ToString()}", exc.Message); |
| 624 | + Assert.AreEqual(correlationId.ToString(), exc.CorrelationId); |
| 625 | + } |
| 626 | + } |
| 627 | + |
539 | 628 | private class CapturingHandler : HttpMessageHandler |
540 | 629 | { |
541 | 630 | public HttpRequestMessage CapturedRequest { get; private set; } |
|
0 commit comments