Skip to content

Commit 631ef39

Browse files
Add Concurrency Test for Multi-Tenant AcquireTokenForClient Scenario (#5192)
* Add Concurrency Test for Multi-Tenant AcquireTokenForClient Scenario * fix per pr comments --------- Co-authored-by: Gladwin Johnson <[email protected]>
1 parent 6025646 commit 631ef39

File tree

2 files changed

+150
-3
lines changed

2 files changed

+150
-3
lines changed

tests/Microsoft.Identity.Test.Unit/Helpers/ParallelRequestMockHandler.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ namespace Microsoft.Identity.Test.Unit.Helpers
2626
internal class ParallelRequestMockHandler : IHttpManager
2727
{
2828
public long LastRequestDurationInMs => 50;
29+
private int _requestCount = 0;
30+
public int RequestsMade => _requestCount;
2931

3032
public async Task<HttpResponse> SendRequestAsync(
3133
Uri endpoint,
@@ -39,6 +41,8 @@ public async Task<HttpResponse> SendRequestAsync(
3941
CancellationToken cancellationToken,
4042
int retryCount = 0)
4143
{
44+
Interlocked.Increment(ref _requestCount);
45+
4246
// simulate delay and also add complexity due to thread context switch
4347
await Task.Delay(ParallelRequestsTests.NetworkAccessPenaltyMs).ConfigureAwait(false);
4448

@@ -53,10 +57,13 @@ public async Task<HttpResponse> SendRequestAsync(
5357
}
5458

5559
if (HttpMethod.Post == method &&
56-
UriWithoutQuery(endpoint).AbsoluteUri.Equals("https://login.microsoftonline.com/my-utid/oauth2/v2.0/token"))
60+
UriWithoutQuery(endpoint).AbsoluteUri.EndsWith("oauth2/v2.0/token"))
5761
{
5862
var bodyString = await (body as FormUrlEncodedContent).ReadAsStringAsync().ConfigureAwait(false);
59-
var bodyDict = bodyString.Replace("?", "").Split('&').ToDictionary(x => x.Split('=')[0], x => x.Split('=')[1]);
63+
var bodyDict = bodyString
64+
.Replace("?", "")
65+
.Split('&')
66+
.ToDictionary(x => x.Split('=')[0], x => x.Split('=')[1]);
6067

6168
if (bodyDict["grant_type"] == "refresh_token")
6269
{
@@ -71,7 +78,12 @@ public async Task<HttpResponse> SendRequestAsync(
7178

7279
if (bodyDict["grant_type"] == "client_credentials")
7380
{
74-
HttpResponseMessage response = MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage();
81+
var segments = endpoint.AbsolutePath.Split('/');
82+
string tid = segments.Length > 1 ? segments[1] : "unknown_tid";
83+
84+
HttpResponseMessage response =
85+
MockHelpers.CreateSuccessfulClientCredentialTokenResponseMessage($"token_{tid}");
86+
7587
string payload = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
7688

7789
return new HttpResponse()

tests/Microsoft.Identity.Test.Unit/ParallelRequestsTests.cs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,141 @@ public async Task ExtraQP()
8585
Assert.AreEqual(NumberOfRequests, results.Length);
8686
}
8787

88+
[TestMethod]
89+
public async Task AcquireTokenForClient_ConcurrentTenantRequests_Test()
90+
{
91+
// Arrange
92+
const int NumberOfRequests = 1000;
93+
94+
// Custom HTTP manager that counts the number of requests
95+
ParallelRequestMockHandler httpManager = new();
96+
97+
var cca = ConfidentialClientApplicationBuilder
98+
.Create(TestConstants.ClientId)
99+
.WithAuthority("https://login.microsoftonline.com/common")
100+
.WithClientSecret(TestConstants.ClientSecret)
101+
.WithHttpManager(httpManager)
102+
.Build();
103+
104+
var tasks = new List<Task<AuthenticationResult>>();
105+
106+
for (int i = 0; i < NumberOfRequests; i++)
107+
{
108+
int tempI = i; // Capture the current value of i
109+
tasks.Add(Task.Run(async () =>
110+
{
111+
string tid = $"tidtid_{tempI}";
112+
AuthenticationResult res = await cca.AcquireTokenForClient(TestConstants.s_scope)
113+
.WithTenantId(tid)
114+
.ExecuteAsync()
115+
.ConfigureAwait(false);
116+
117+
Assert.IsFalse(
118+
string.IsNullOrEmpty(res.AuthenticationResultMetadata.TokenEndpoint),
119+
"TokenEndpoint is null/empty!"
120+
);
121+
Assert.IsTrue(
122+
res.AuthenticationResultMetadata.TokenEndpoint.Contains(tid),
123+
"TokenEndpoint should contain the tenant ID."
124+
);
125+
Assert.AreEqual($"token_{tid}", res.AccessToken, "Access token did not match the expected value.");
126+
127+
return res;
128+
}));
129+
}
130+
131+
// Wait for all tasks to complete
132+
AuthenticationResult[] results = await Task.WhenAll(tasks).ConfigureAwait(false);
133+
134+
// Assert the total tasks
135+
Assert.AreEqual(NumberOfRequests, results.Length, "Number of AuthenticationResult objects does not match the number of requests.");
136+
}
137+
138+
[TestMethod]
139+
public async Task AcquireTokenForClient_PerTenantCaching_Test()
140+
{
141+
const int NumberOfRequests = 5000;
142+
143+
var httpManager = new ParallelRequestMockHandler();
144+
IConfidentialClientApplication cca = ConfidentialClientApplicationBuilder
145+
.Create(TestConstants.ClientId)
146+
.WithAuthority("https://login.microsoftonline.com/common")
147+
.WithClientSecret(TestConstants.ClientSecret)
148+
.WithHttpManager(httpManager)
149+
.Build();
150+
151+
// First pass: tokens should come from the network
152+
var tasksFirstPass = new List<Task<AuthenticationResult>>();
153+
for (int i = 0; i < NumberOfRequests; i++)
154+
{
155+
int tempI = i; // Capture the current value of i
156+
string tid = $"tidtid_{tempI}";
157+
tasksFirstPass.Add(Task.Run(async () =>
158+
{
159+
AuthenticationResult result = await cca
160+
.AcquireTokenForClient(TestConstants.s_scope)
161+
.WithTenantId(tid)
162+
.ExecuteAsync()
163+
.ConfigureAwait(false);
164+
165+
Assert.IsNotNull(result, $"First-pass result is null for TID '{tid}'.");
166+
Assert.IsFalse(
167+
string.IsNullOrEmpty(result.AccessToken),
168+
$"First-pass access token is null/empty for TID '{tid}'.");
169+
Assert.AreEqual(
170+
$"token_{tid}",
171+
result.AccessToken,
172+
$"First-pass AccessToken mismatch for TID '{tid}'.");
173+
Assert.IsTrue(
174+
result.AuthenticationResultMetadata.TokenEndpoint.Contains(tid),
175+
$"First-pass TokenEndpoint '{result.AuthenticationResultMetadata.TokenEndpoint}' does not contain TID '{tid}'.");
176+
177+
return result;
178+
}));
179+
}
180+
181+
AuthenticationResult[] firstPassResults = await Task.WhenAll(tasksFirstPass).ConfigureAwait(false);
182+
int firstPassRequestsMade = httpManager.RequestsMade;
183+
184+
// Second pass: tokens should come from the cache
185+
var tasksSecondPass = new List<Task<AuthenticationResult>>();
186+
for (int i = 0; i < NumberOfRequests; i++)
187+
{
188+
int tempI = i; // Capture the current value of i
189+
string tid = $"tidtid_{tempI}";
190+
tasksSecondPass.Add(Task.Run(async () =>
191+
{
192+
AuthenticationResult result = await cca
193+
.AcquireTokenForClient(TestConstants.s_scope)
194+
.WithTenantId(tid)
195+
.ExecuteAsync()
196+
.ConfigureAwait(false);
197+
198+
Assert.IsNotNull(result, $"Second-pass result is null for TID '{tid}'.");
199+
Assert.IsFalse(
200+
string.IsNullOrEmpty(result.AccessToken),
201+
$"Second-pass access token is null/empty for TID '{tid}'.");
202+
Assert.AreEqual(
203+
$"token_{tid}",
204+
result.AccessToken,
205+
$"Second-pass AccessToken mismatch for TID '{tid}'.");
206+
207+
return result;
208+
}));
209+
}
210+
211+
AuthenticationResult[] secondPassResults = await Task.WhenAll(tasksSecondPass).ConfigureAwait(false);
212+
int totalRequestsMade = httpManager.RequestsMade;
213+
int secondPassRequestsMade = totalRequestsMade - firstPassRequestsMade;
214+
215+
// Verifying no new network calls on the second pass if caching is working properly
216+
Assert.AreEqual(
217+
0,
218+
secondPassRequestsMade,
219+
$"Expected zero new requests in second pass, but found {secondPassRequestsMade}."
220+
);
221+
}
222+
88223
[TestMethod]
89224
public async Task AcquireTokenSilent_ValidATs_ParallelRequests_Async()
90225
{

0 commit comments

Comments
 (0)