4
4
using System ;
5
5
using System . Collections . Generic ;
6
6
using System . IO ;
7
+ using System . Net ;
7
8
using System . Net . Http ;
8
9
using System . Threading ;
9
10
using System . Threading . Tasks ;
10
11
using Microsoft . Identity . Client ;
12
+ using Microsoft . Identity . Client . Http . Retry ;
11
13
using Microsoft . Identity . Client . Instance . Discovery ;
12
14
using Microsoft . Identity . Client . Internal ;
13
15
using Microsoft . Identity . Client . Region ;
14
16
using Microsoft . Identity . Client . TelemetryCore . Internal . Events ;
15
17
using Microsoft . Identity . Test . Common . Core . Mocks ;
18
+ using Microsoft . Identity . Test . Unit . Helpers ;
16
19
using Microsoft . VisualStudio . TestTools . UnitTesting ;
17
20
18
21
namespace Microsoft . Identity . Test . Unit . CoreTests
@@ -28,13 +31,15 @@ public class RegionDiscoveryProviderTests : TestBase
28
31
private ApiEvent _apiEvent ;
29
32
private CancellationTokenSource _userCancellationTokenSource ;
30
33
private IRegionDiscoveryProvider _regionDiscoveryProvider ;
34
+ private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory ( ) ;
31
35
32
36
[ TestInitialize ]
33
37
public override void TestInitialize ( )
34
38
{
35
39
base . TestInitialize ( ) ;
36
40
37
41
_harness = base . CreateTestHarness ( ) ;
42
+ _harness . ServiceBundle . Config . RetryPolicyFactory = _testRetryPolicyFactory ;
38
43
_httpManager = _harness . HttpManager ;
39
44
_userCancellationTokenSource = new CancellationTokenSource ( ) ;
40
45
_testRequestContext = new RequestContext (
@@ -145,10 +150,18 @@ public async Task FetchRegionFromLocalImdsThenGetMetadataFromCacheAsync()
145
150
ValidateInstanceMetadata ( regionalMetadata ) ;
146
151
}
147
152
148
- [ TestMethod ]
149
- public async Task SuccessfulResponseFromUserProvidedRegionAsync ( )
153
+ [ DataTestMethod ]
154
+ [ DataRow ( HttpStatusCode . NotFound , 0 , TestConstants . RegionAutoDetectNotFoundFailureMessage ) ] // No retries for 404 errors
155
+ [ DataRow ( HttpStatusCode . InternalServerError , TestRegionDiscoveryRetryPolicy . NumRetries , TestConstants . RegionAutoDetectInternalServerErrorFailureMessage ) ]
156
+ public async Task SuccessfulResponseFromUserProvidedRegionAsync (
157
+ HttpStatusCode statusCode ,
158
+ int expectedRetries ,
159
+ string expectedFailureMessage )
150
160
{
151
- AddMockedResponse ( MockHelpers . CreateNullMessage ( System . Net . HttpStatusCode . NotFound ) ) ;
161
+ for ( int i = 0 ; i < ( 1 + expectedRetries ) ; i ++ )
162
+ {
163
+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
164
+ }
152
165
153
166
_testRequestContext . ServiceBundle . Config . AzureRegion = TestConstants . Region ;
154
167
@@ -162,7 +175,10 @@ public async Task SuccessfulResponseFromUserProvidedRegionAsync()
162
175
Assert . AreEqual ( TestConstants . Region , _testRequestContext . ApiEvent . RegionUsed ) ;
163
176
Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
164
177
Assert . AreEqual ( RegionOutcome . UserProvidedAutodetectionFailed , _testRequestContext . ApiEvent . RegionOutcome ) ;
165
- Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectNotFoundFailureMessage ) ) ;
178
+ Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( expectedFailureMessage ) ) ;
179
+
180
+ // Verify all mock responses were consumed
181
+ Assert . AreEqual ( 0 , _httpManager . QueueSize ) ;
166
182
}
167
183
168
184
[ TestMethod ]
@@ -303,10 +319,18 @@ public async Task ResponseMissingRegionFromLocalImdsAsync()
303
319
Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectOkFailureMessage ) ) ;
304
320
}
305
321
306
- [ TestMethod ]
307
- public async Task ErrorResponseFromLocalImdsAsync ( )
322
+ [ DataTestMethod ]
323
+ [ DataRow ( HttpStatusCode . NotFound , 0 , TestConstants . RegionAutoDetectNotFoundFailureMessage ) ] // No retries for 404 errors
324
+ [ DataRow ( HttpStatusCode . InternalServerError , TestRegionDiscoveryRetryPolicy . NumRetries , TestConstants . RegionAutoDetectInternalServerErrorFailureMessage ) ]
325
+ public async Task ErrorResponseFromLocalImdsAsync (
326
+ HttpStatusCode statusCode ,
327
+ int expectedRetries ,
328
+ string expectedFailureMessage )
308
329
{
309
- AddMockedResponse ( MockHelpers . CreateNullMessage ( System . Net . HttpStatusCode . NotFound ) ) ;
330
+ for ( int i = 0 ; i < ( 1 + expectedRetries ) ; i ++ )
331
+ {
332
+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
333
+ }
310
334
_testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
311
335
312
336
InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider .
@@ -318,7 +342,10 @@ public async Task ErrorResponseFromLocalImdsAsync()
318
342
Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
319
343
Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
320
344
Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
321
- Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionAutoDetectNotFoundFailureMessage ) ) ;
345
+ Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( expectedFailureMessage ) ) ;
346
+
347
+ // Verify all mock responses were consumed
348
+ Assert . AreEqual ( 0 , _httpManager . QueueSize ) ;
322
349
}
323
350
324
351
[ TestMethod ]
@@ -382,6 +409,75 @@ public async Task UpdateApiversionFailsWithNoNewestVersionsAsync()
382
409
Assert . IsTrue ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason . Contains ( TestConstants . RegionDiscoveryNotSupportedErrorMessage ) ) ;
383
410
}
384
411
412
+ [ TestMethod ]
413
+ public async Task RegionDiscoveryFails500OnceThenSucceeds200Async ( )
414
+ {
415
+ AddMockedResponse ( MockHelpers . CreateNullMessage ( HttpStatusCode . InternalServerError ) ) ;
416
+ AddMockedResponse ( MockHelpers . CreateSuccessResponseMessage ( TestConstants . Region ) ) ;
417
+
418
+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
419
+
420
+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
421
+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
422
+
423
+ ValidateInstanceMetadata ( regionalMetadata ) ;
424
+ Assert . AreEqual ( TestConstants . Region , _testRequestContext . ApiEvent . RegionUsed ) ;
425
+ Assert . AreEqual ( RegionAutodetectionSource . Imds , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
426
+ Assert . AreEqual ( RegionOutcome . AutodetectSuccess , _testRequestContext . ApiEvent . RegionOutcome ) ;
427
+ Assert . IsNull ( _testRequestContext . ApiEvent . RegionDiscoveryFailureReason ) ;
428
+
429
+ const int NumRequests = 2 ; // initial request + one retry
430
+ int requestsMade = NumRequests - _httpManager . QueueSize ;
431
+ Assert . AreEqual ( NumRequests , requestsMade ) ;
432
+ }
433
+
434
+ [ TestMethod ]
435
+ public async Task RegionDiscoveryFails500PermanentlyAsync ( )
436
+ {
437
+ // Simulate permanent 500s (to trigger the maximum number of retries)
438
+ const int Num500Errors = 1 + RegionDiscoveryRetryPolicy . NumRetries ; // initial request + maximum number of retries
439
+ for ( int i = 0 ; i < Num500Errors ; i ++ )
440
+ {
441
+ AddMockedResponse ( MockHelpers . CreateNullMessage ( HttpStatusCode . InternalServerError ) ) ;
442
+ }
443
+
444
+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
445
+
446
+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
447
+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
448
+
449
+ Assert . IsNull ( regionalMetadata , "Discovery should fail after max retries" ) ;
450
+ Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
451
+ Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
452
+ Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
453
+
454
+ const int NumRequests = Num500Errors ; // initial request + three retries
455
+ int requestsMade = NumRequests - _httpManager . QueueSize ;
456
+ Assert . AreEqual ( NumRequests , requestsMade ) ;
457
+ }
458
+
459
+ [ DataTestMethod ]
460
+ [ DataRow ( HttpStatusCode . NotFound ) ]
461
+ [ DataRow ( HttpStatusCode . RequestTimeout ) ]
462
+ public async Task RegionDiscoveryDoesNotRetryOnNonRetryableStatusCodesAsync ( HttpStatusCode statusCode )
463
+ {
464
+ AddMockedResponse ( MockHelpers . CreateNullMessage ( statusCode ) ) ;
465
+
466
+ _testRequestContext . ServiceBundle . Config . AzureRegion = ConfidentialClientApplication . AttemptRegionDiscovery ;
467
+
468
+ InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider . GetMetadataAsync (
469
+ new Uri ( "https://login.microsoftonline.com/common/" ) , _testRequestContext ) . ConfigureAwait ( false ) ;
470
+
471
+ Assert . IsNull ( regionalMetadata , "Discovery should fail and not retry" ) ;
472
+ Assert . AreEqual ( null , _testRequestContext . ApiEvent . RegionUsed ) ;
473
+ Assert . AreEqual ( RegionAutodetectionSource . FailedAutoDiscovery , _testRequestContext . ApiEvent . RegionAutodetectionSource ) ;
474
+ Assert . AreEqual ( RegionOutcome . FallbackToGlobal , _testRequestContext . ApiEvent . RegionOutcome ) ;
475
+
476
+ const int NumRequests = 1 ; // initial request + 0 retries (non-retryable status codes should not trigger retry)
477
+ int requestsMade = NumRequests - _httpManager . QueueSize ;
478
+ Assert . AreEqual ( NumRequests , requestsMade ) ;
479
+ }
480
+
385
481
private void AddMockedResponse ( HttpResponseMessage responseMessage , string apiVersion = "2020-06-01" , bool expectedParams = true )
386
482
{
387
483
var queryParams = new Dictionary < string , string > ( ) ;
0 commit comments