|
5 | 5 | using System.Collections.Generic; |
6 | 6 | using System.IO; |
7 | 7 | using System.Text; |
| 8 | +using System.Threading; |
8 | 9 | using System.Threading.Tasks; |
9 | 10 | using Azure.Core; |
10 | 11 | using Azure.Core.Pipeline; |
@@ -174,6 +175,64 @@ public void ManagedIdentityCredentialUsesDefaultTimeoutAndRetries() |
174 | 175 | CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); |
175 | 176 | } |
176 | 177 |
|
| 178 | + [Test] |
| 179 | + public void ManagedIdentityCredentialRetryBehaviorIsOverriddenWithOptions() |
| 180 | + { |
| 181 | + int callCount = 0; |
| 182 | + List<TimeSpan?> networkTimeouts = new(); |
| 183 | + |
| 184 | + var mockTransport = MockTransport.FromMessageCallback(msg => |
| 185 | + { |
| 186 | + callCount++; |
| 187 | + networkTimeouts.Add(msg.NetworkTimeout); |
| 188 | + Assert.IsTrue(msg.Request.Headers.TryGetValue(ImdsManagedIdentitySource.metadataHeaderName, out _)); |
| 189 | + return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json"); |
| 190 | + }); |
| 191 | + |
| 192 | + var options = new TokenCredentialOptions() |
| 193 | + { |
| 194 | + Transport = mockTransport, |
| 195 | + RetryPolicy = new RetryPolicy(1, DelayStrategy.CreateFixedDelayStrategy(TimeSpan.Zero)) |
| 196 | + }; |
| 197 | + options.Retry.MaxDelay = TimeSpan.Zero; |
| 198 | + |
| 199 | + var cred = new ManagedIdentityCredential( |
| 200 | + "testCLientId", options); |
| 201 | + |
| 202 | + Assert.ThrowsAsync<AuthenticationFailedException>(async () => await cred.GetTokenAsync(new(new[] { "test" }))); |
| 203 | + |
| 204 | + var expectedTimeouts = new TimeSpan?[] { null, null }; |
| 205 | + CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); |
| 206 | + } |
| 207 | + |
| 208 | + [Test] |
| 209 | + public void ManagedIdentityCredentialRespectsCancellationToken() |
| 210 | + { |
| 211 | + int callCount = 0; |
| 212 | + |
| 213 | + var mockTransport = MockTransport.FromMessageCallback(msg => |
| 214 | + { |
| 215 | + Task.Delay(1000).GetAwaiter().GetResult(); |
| 216 | + callCount++; |
| 217 | + return CreateMockResponse(500, "Error").WithHeader("Content-Type", "application/json"); |
| 218 | + }); |
| 219 | + |
| 220 | + var options = new TokenCredentialOptions() { Transport = mockTransport }; |
| 221 | + options.Retry.MaxDelay = TimeSpan.FromSeconds(1); |
| 222 | + |
| 223 | + var cred = new ManagedIdentityCredential( |
| 224 | + "testCLientId", options); |
| 225 | + |
| 226 | + var cts = new CancellationTokenSource(); |
| 227 | + cts.CancelAfter(TimeSpan.Zero); |
| 228 | + var ex = Assert.CatchAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }), cts.Token)); |
| 229 | + Assert.IsTrue(ex is TaskCanceledException || ex is OperationCanceledException, "Expected TaskCanceledException or OperationCanceledException but got " + ex.GetType().ToString()); |
| 230 | + |
| 231 | + // Default number of retries is 5, so we should just ensure we have less than that. |
| 232 | + // Timing on some platforms makes this test somewhat non-deterministic, so we just ensure we have less than 2 calls. |
| 233 | + Assert.Less(callCount, 2); |
| 234 | + } |
| 235 | + |
177 | 236 | private MockResponse CreateMockResponse(int responseCode, string token) |
178 | 237 | { |
179 | 238 | var response = new MockResponse(responseCode); |
|
0 commit comments