44using System ;
55using System . Collections . Generic ;
66using System . IO ;
7+ using System . Net ;
78using System . Net . Http ;
89using System . Threading ;
910using System . Threading . Tasks ;
1011using Microsoft . Identity . Client ;
12+ using Microsoft . Identity . Client . Http . Retry ;
1113using Microsoft . Identity . Client . Instance . Discovery ;
1214using Microsoft . Identity . Client . Internal ;
1315using Microsoft . Identity . Client . Region ;
1416using Microsoft . Identity . Client . TelemetryCore . Internal . Events ;
1517using Microsoft . Identity . Test . Common . Core . Mocks ;
18+ using Microsoft . Identity . Test . Unit . Helpers ;
1619using Microsoft . VisualStudio . TestTools . UnitTesting ;
1720
1821namespace Microsoft . Identity . Test . Unit . CoreTests
@@ -28,13 +31,15 @@ public class RegionDiscoveryProviderTests : TestBase
2831 private ApiEvent _apiEvent ;
2932 private CancellationTokenSource _userCancellationTokenSource ;
3033 private IRegionDiscoveryProvider _regionDiscoveryProvider ;
34+ private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory ( ) ;
3135
3236 [ TestInitialize ]
3337 public override void TestInitialize ( )
3438 {
3539 base . TestInitialize ( ) ;
3640
3741 _harness = base . CreateTestHarness ( ) ;
42+ _harness . ServiceBundle . Config . RetryPolicyFactory = _testRetryPolicyFactory ;
3843 _httpManager = _harness . HttpManager ;
3944 _userCancellationTokenSource = new CancellationTokenSource ( ) ;
4045 _testRequestContext = new RequestContext (
@@ -145,10 +150,18 @@ public async Task FetchRegionFromLocalImdsThenGetMetadataFromCacheAsync()
145150 ValidateInstanceMetadata ( regionalMetadata ) ;
146151 }
147152
148- [ TestMethod ]
149- public async Task SuccessfulResponseFromUserProvidedRegionAsync ( )
153+ [ DataTestMethod ]
154+ [ DataRow ( HttpStatusCode . NotFound , 0 , TestConstants . RegionAutoDetectNotFoundFailureMessage ) ] // No retries for 404 errors
155+ [ DataRow ( HttpStatusCode . InternalServerError , TestRegionDiscoveryRetryPolicy . NumRetries , TestConstants . RegionAutoDetectInternalServerErrorFailureMessage ) ]
156+ public async Task SuccessfulResponseFromUserProvidedRegionAsync (
157+ HttpStatusCode statusCode ,
158+ int expectedRetries ,
159+ string expectedFailureMessage )
150160 {
151- AddMockedResponse ( MockHelpers . CreateNullMessage ( System . Net . HttpStatusCode . NotFound ) ) ;
161+ for ( int i = 0 ; i < ( 1 + expectedRetries ) ; i ++ )
162+ {
163+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
164+ }
152165
153166 _testRequestContext . ServiceBundle . Config . AzureRegion = TestConstants . Region ;
154167
@@ -162,7 +175,10 @@ public async Task SuccessfulResponseFromUserProvidedRegionAsync()
162175 Assert . AreEqual ( TestConstants . Region , _testRequestContext . ApiEvent . RegionUsed ) ;
163176 Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
164177 Assert . AreEqual ( RegionOutcome . UserProvidedAutodetectionFailed , _testRequestContext . ApiEvent . RegionOutcome ) ;
165- Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectNotFoundFailureMessage ) ) ;
178+ Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( expectedFailureMessage ) ) ;
179+
180+ // Verify all mock responses were consumed
181+ Assert . AreEqual ( 0 , _httpManager . QueueSize ) ;
166182 }
167183
168184 [ TestMethod ]
@@ -303,10 +319,18 @@ public async Task ResponseMissingRegionFromLocalImdsAsync()
303319 Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectOkFailureMessage ) ) ;
304320 }
305321
306- [ TestMethod ]
307- public async Task ErrorResponseFromLocalImdsAsync ( )
322+ [ DataTestMethod ]
323+ [ DataRow ( HttpStatusCode . NotFound , 0 , TestConstants . RegionAutoDetectNotFoundFailureMessage ) ] // No retries for 404 errors
324+ [ DataRow ( HttpStatusCode . InternalServerError , TestRegionDiscoveryRetryPolicy . NumRetries , TestConstants . RegionAutoDetectInternalServerErrorFailureMessage ) ]
325+ public async Task ErrorResponseFromLocalImdsAsync (
326+ HttpStatusCode statusCode ,
327+ int expectedRetries ,
328+ string expectedFailureMessage )
308329 {
309- AddMockedResponse ( MockHelpers . CreateNullMessage ( System . Net . HttpStatusCode . NotFound ) ) ;
330+ for ( int i = 0 ; i < ( 1 + expectedRetries ) ; i ++ )
331+ {
332+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
333+ }
310334 _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
311335
312336 InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider .
@@ -318,7 +342,10 @@ public async Task ErrorResponseFromLocalImdsAsync()
318342 Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
319343 Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
320344 Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
321- Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectNotFoundFailureMessage ) ) ;
345+ Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( expectedFailureMessage ) ) ;
346+
347+ // Verify all mock responses were consumed
348+ Assert . AreEqual ( 0 , _httpManager . QueueSize ) ;
322349 }
323350
324351 [ TestMethod ]
@@ -382,6 +409,75 @@ public async Task UpdateApiversionFailsWithNoNewestVersionsAsync()
382409 Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionDiscoveryNotSupportedErrorMessage ) ) ;
383410 }
384411
412+ [ TestMethod ]
413+ public async Task RegionDiscoveryFails500OnceThenSucceeds200Async ( )
414+ {
415+ AddMockedResponse ( MockHelpers . CreateNullMessage ( HttpStatusCode . InternalServerError ) ) ;
416+ AddMockedResponse ( MockHelpers . CreateSuccessResponseMessage ( TestConstants . Region ) ) ;
417+
418+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
419+
420+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
421+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
422+
423+ ValidateInstanceMetadata ( regionalMetadata ) ;
424+ Assert . AreEqual ( TestConstants . Region , _testRequestContext . ApiEvent . RegionUsed ) ;
425+ Assert . AreEqual ( RegionAutodetectionSource . Imds , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
426+ Assert . AreEqual ( RegionOutcome . AutodetectSuccess , _testRequestContext . ApiEvent . RegionOutcome ) ;
427+ Assert . IsNull ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason ) ;
428+
429+ const int NumRequests = 2 ; // initial request + one retry
430+ int requestsMade = NumRequests - _httpManager . QueueSize ;
431+ Assert . AreEqual ( NumRequests , requestsMade ) ;
432+ }
433+
434+ [ TestMethod ]
435+ public async Task RegionDiscoveryFails500PermanentlyAsync ( )
436+ {
437+ // Simulate permanent 500s (to trigger the maximum number of retries)
438+ const int Num500Errors = 1 + RegionDiscoveryRetryPolicy . NumRetries ; // initial request + maximum number of retries
439+ for ( int i = 0 ; i < Num500Errors ; i ++ )
440+ {
441+ AddMockedResponse ( MockHelpers . CreateNullMessage ( HttpStatusCode . InternalServerError ) ) ;
442+ }
443+
444+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
445+
446+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
447+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
448+
449+ Assert . IsNull ( regionalMetadata , "Discovery should fail after max retries" ) ;
450+ Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
451+ Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
452+ Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
453+
454+ const int NumRequests = Num500Errors ; // initial request + three retries
455+ int requestsMade = NumRequests - _httpManager . QueueSize ;
456+ Assert . AreEqual ( NumRequests , requestsMade ) ;
457+ }
458+
459+ [ DataTestMethod ]
460+ [ DataRow ( HttpStatusCode . NotFound ) ]
461+ [ DataRow ( HttpStatusCode . RequestTimeout ) ]
462+ public async Task RegionDiscoveryDoesNotRetryOnNonRetryableStatusCodesAsync ( HttpStatusCode statusCode )
463+ {
464+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
465+
466+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
467+
468+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
469+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
470+
471+ Assert . IsNull ( regionalMetadata , "Discovery should fail and not retry" ) ;
472+ Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
473+ Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
474+ Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
475+
476+ const int NumRequests = 1 ; // initial request + 0 retries (non-retryable status codes should not trigger retry)
477+ int requestsMade = NumRequests - _httpManager . QueueSize ;
478+ Assert . AreEqual ( NumRequests , requestsMade ) ;
479+ }
480+
385481 private void AddMockedResponse ( HttpResponseMessage responseMessage , string apiVersion = "2020-06-01" , bool expectedParams = true )
386482 {
387483 var queryParams = new Dictionary < string , string > ( ) ;
0 commit comments