Skip to content

Commit d65173b

Browse files
Implement public client application global cache (#770)
1 parent 2134081 commit d65173b

File tree

1 file changed

+140
-38
lines changed

1 file changed

+140
-38
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 140 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Concurrent;
67
using System.Linq;
78
using System.Security;
89
using System.Threading;
@@ -15,6 +16,14 @@ namespace Microsoft.Data.SqlClient
1516
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/ActiveDirectoryAuthenticationProvider/*'/>
1617
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
1718
{
19+
/// <summary>
20+
/// This is a static cache instance meant to hold instances of "PublicClientApplication" mapping to information available in PublicClientAppKey.
21+
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
22+
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
23+
/// </summary>
24+
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
25+
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
26+
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
1827
private static readonly string s_defaultScopeSuffix = "/.default";
1928
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
2029
private readonly SqlClientLogger _logger = new SqlClientLogger();
@@ -67,10 +76,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
6776
}
6877

6978
#if NETSTANDARD
70-
private Func<object> parentActivityOrWindowFunc = null;
79+
private Func<object> _parentActivityOrWindowFunc = null;
7180

7281
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetParentActivityOrWindowFunc/*'/>
73-
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
82+
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
7483
#endif
7584

7685
#if NETFRAMEWORK
@@ -108,51 +117,24 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
108117
*
109118
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
110119
*/
111-
string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
120+
string redirectUri = s_nativeClientRedirectUri;
112121

113122
#if NETCOREAPP
114123
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
115124
{
116-
redirectURI = "http://localhost";
117-
}
118-
#endif
119-
IPublicClientApplication app;
120-
121-
#if NETSTANDARD
122-
if (parentActivityOrWindowFunc != null)
123-
{
124-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
125-
.WithAuthority(parameters.Authority)
126-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
127-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
128-
.WithRedirectUri(redirectURI)
129-
.WithParentActivityOrWindow(parentActivityOrWindowFunc)
130-
.Build();
125+
redirectUri = "http://localhost";
131126
}
132127
#endif
128+
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
133129
#if NETFRAMEWORK
134-
if (_iWin32WindowFunc != null)
135-
{
136-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
137-
.WithAuthority(parameters.Authority)
138-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
139-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
140-
.WithRedirectUri(redirectURI)
141-
.WithParentActivityOrWindow(_iWin32WindowFunc)
142-
.Build();
143-
}
130+
, _iWin32WindowFunc
144131
#endif
145-
#if !NETCOREAPP
146-
else
132+
#if NETSTANDARD
133+
, _parentActivityOrWindowFunc
147134
#endif
148-
{
149-
app = PublicClientApplicationBuilder.Create(_applicationClientId)
150-
.WithAuthority(parameters.Authority)
151-
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
152-
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
153-
.WithRedirectUri(redirectURI)
154-
.Build();
155-
}
135+
);
136+
137+
IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
156138

157139
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
158140
{
@@ -185,6 +167,7 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
185167
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
186168
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
187169
{
170+
// Fetch available accounts from 'app' instance
188171
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
189172
IAccount account;
190173
if (!string.IsNullOrEmpty(parameters.UserId))
@@ -200,17 +183,23 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
200183
{
201184
try
202185
{
186+
// If 'account' is available in 'app', we use the same to acquire token silently.
187+
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
203188
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
204189
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
205190
}
206191
catch (MsalUiRequiredException)
207192
{
193+
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
194+
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
195+
// or the user needs to perform two factor authentication.
208196
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
209197
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
210198
}
211199
}
212200
else
213201
{
202+
// If no existing 'account' is found, we request user to sign in interactively.
214203
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
215204
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
216205
}
@@ -320,5 +309,118 @@ private class CustomWebUi : ICustomWebUi
320309
public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
321310
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
322311
}
312+
313+
private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
314+
{
315+
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
316+
{
317+
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
318+
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
319+
}
320+
return clientApplicationInstance;
321+
}
322+
323+
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
324+
{
325+
IPublicClientApplication publicClientApplication;
326+
327+
#if NETSTANDARD
328+
if (_parentActivityOrWindowFunc != null)
329+
{
330+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
331+
.WithAuthority(publicClientAppKey._authority)
332+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
333+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
334+
.WithRedirectUri(publicClientAppKey._redirectUri)
335+
.WithParentActivityOrWindow(_parentActivityOrWindowFunc)
336+
.Build();
337+
}
338+
#endif
339+
#if NETFRAMEWORK
340+
if (_iWin32WindowFunc != null)
341+
{
342+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
343+
.WithAuthority(publicClientAppKey._authority)
344+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
345+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
346+
.WithRedirectUri(publicClientAppKey._redirectUri)
347+
.WithParentActivityOrWindow(_iWin32WindowFunc)
348+
.Build();
349+
}
350+
#endif
351+
#if !NETCOREAPP
352+
else
353+
#endif
354+
{
355+
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
356+
.WithAuthority(publicClientAppKey._authority)
357+
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
358+
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
359+
.WithRedirectUri(publicClientAppKey._redirectUri)
360+
.Build();
361+
}
362+
363+
return publicClientApplication;
364+
}
365+
366+
internal class PublicClientAppKey
367+
{
368+
public readonly string _authority;
369+
public readonly string _redirectUri;
370+
public readonly string _applicationClientId;
371+
#if NETFRAMEWORK
372+
public readonly Func<System.Windows.Forms.IWin32Window> _iWin32WindowFunc;
373+
#endif
374+
#if NETSTANDARD
375+
public readonly Func<object> _parentActivityOrWindowFunc;
376+
#endif
377+
378+
public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
379+
#if NETFRAMEWORK
380+
, Func<System.Windows.Forms.IWin32Window> iWin32WindowFunc
381+
#endif
382+
#if NETSTANDARD
383+
, Func<object> parentActivityOrWindowFunc
384+
#endif
385+
)
386+
{
387+
_authority = authority;
388+
_redirectUri = redirectUri;
389+
_applicationClientId = applicationClientId;
390+
#if NETFRAMEWORK
391+
_iWin32WindowFunc = iWin32WindowFunc;
392+
#endif
393+
#if NETSTANDARD
394+
_parentActivityOrWindowFunc = parentActivityOrWindowFunc;
395+
#endif
396+
}
397+
398+
public override bool Equals(object obj)
399+
{
400+
if (obj != null && obj is PublicClientAppKey pcaKey)
401+
{
402+
return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
403+
&& string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
404+
&& string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
405+
#if NETFRAMEWORK
406+
&& pcaKey._iWin32WindowFunc == _iWin32WindowFunc
407+
#endif
408+
#if NETSTANDARD
409+
&& pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
410+
#endif
411+
);
412+
}
413+
return false;
414+
}
415+
416+
public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _applicationClientId
417+
#if NETFRAMEWORK
418+
, _iWin32WindowFunc
419+
#endif
420+
#if NETSTANDARD
421+
, _parentActivityOrWindowFunc
422+
#endif
423+
).GetHashCode();
424+
}
323425
}
324426
}

0 commit comments

Comments
 (0)