Skip to content

Commit 3487389

Browse files
authored
Fix | Fix driver to not send expired token and refresh token first before sending it. (#2273)
1 parent 9347412 commit 3487389

File tree

5 files changed

+195
-1
lines changed

5 files changed

+195
-1
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,6 +2256,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
22562256
{
22572257
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
22582258
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");
2259+
2260+
if (_newDbConnectionPoolAuthenticationContext != null)
2261+
{
2262+
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
2263+
}
22592264
}
22602265
}
22612266
else if (!attemptRefreshTokenLocked)

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,6 +2683,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
26832683
{
26842684
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
26852685
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");
2686+
2687+
if (_newDbConnectionPoolAuthenticationContext != null)
2688+
{
2689+
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
2690+
}
26862691
}
26872692
}
26882693
else if (!attemptRefreshTokenLocked)

src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
<Compile Include="ProviderAgnostic\MultipleResultsTest\MultipleResultsTest.cs" />
175175
<Compile Include="ProviderAgnostic\ReaderTest\ReaderTest.cs" />
176176
<Compile Include="TracingTests\EventSourceTest.cs" />
177+
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
177178
<Compile Include="SQL\ConnectionPoolTest\ConnectionPoolTest.cs" />
178179
<Compile Include="SQL\ConnectionPoolTest\PoolBlockPeriodTest.cs" />
179180
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
@@ -279,6 +280,7 @@
279280
<Compile Include="SQL\Common\SystemDataInternals\ConnectionHelper.cs" />
280281
<Compile Include="SQL\Common\SystemDataInternals\ConnectionPoolHelper.cs" />
281282
<Compile Include="SQL\Common\SystemDataInternals\DataReaderHelper.cs" />
283+
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
282284
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
283285
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
284286
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
@@ -342,7 +344,7 @@
342344
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="$(SystemIdentityModelTokensJwtVersion)" />
343345
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
344346
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
345-
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
347+
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
346348
<PackageReference Condition="'$(TargetGroup)'!='netfx'" Include="System.ServiceProcess.ServiceController" Version="$(SystemServiceProcessServiceControllerVersion)" />
347349
</ItemGroup>
348350
<ItemGroup>
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
7+
using Xunit;
8+
using Xunit.Abstractions;
9+
10+
namespace Microsoft.Data.SqlClient.ManualTesting.Tests
11+
{
12+
public class AADFedAuthTokenRefreshTest
13+
{
14+
private readonly ITestOutputHelper _testOutputHelper;
15+
16+
public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
17+
{
18+
_testOutputHelper = testOutputHelper;
19+
}
20+
21+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
22+
public void FedAuthTokenRefreshTest()
23+
{
24+
string connectionString = DataTestUtility.AADPasswordConnectionString;
25+
26+
using (SqlConnection connection = new SqlConnection(connectionString))
27+
{
28+
connection.Open();
29+
30+
string oldTokenHash = "";
31+
DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
32+
Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");
33+
34+
// Convert and display the old expiry into local time which should be in 1 minute from now
35+
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
36+
LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
37+
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
38+
Assert.InRange(timeDiff.TotalSeconds, 0, 60);
39+
40+
// Check if connection is still alive to continue further testing
41+
string result = "";
42+
SqlCommand cmd = connection.CreateCommand();
43+
cmd.CommandText = "select @@version";
44+
result = $"{cmd.ExecuteScalar()}";
45+
Assert.True(result != string.Empty, "The connection's command must return a value");
46+
47+
// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
48+
using (SqlConnection connection2 = new SqlConnection(connectionString))
49+
{
50+
connection2.Open();
51+
52+
// Check if connection is alive
53+
cmd = connection2.CreateCommand();
54+
cmd.CommandText = "select 1";
55+
result = $"{cmd.ExecuteScalar()}";
56+
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");
57+
58+
string newTokenHash = "";
59+
DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
60+
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
61+
LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");
62+
63+
Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
64+
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
65+
}
66+
}
67+
}
68+
69+
private void LogInfo(string message)
70+
{
71+
_testOutputHelper.WriteLine(message);
72+
}
73+
}
74+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections;
7+
using System.Linq;
8+
using System.Reflection;
9+
10+
namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
11+
{
12+
internal static class FedAuthTokenHelper
13+
{
14+
internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
15+
{
16+
try
17+
{
18+
object authenticationContextValueObj = GetAuthenticationContextValue(connection);
19+
20+
DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
21+
22+
tokenHash = GetTokenHash(authenticationContextValueObj);
23+
24+
return expirationTimeProperty;
25+
}
26+
catch (Exception)
27+
{
28+
tokenHash = "";
29+
return null;
30+
}
31+
}
32+
33+
internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
34+
{
35+
try
36+
{
37+
object authenticationContextValueObj = GetAuthenticationContextValue(connection);
38+
39+
DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
40+
41+
expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);
42+
43+
FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
44+
expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);
45+
46+
tokenHash = GetTokenHash(authenticationContextValueObj);
47+
48+
return expirationTimeProperty;
49+
}
50+
catch (Exception)
51+
{
52+
tokenHash = "";
53+
return null;
54+
}
55+
}
56+
57+
internal static string GetTokenHash(object authenticationContextValueObj)
58+
{
59+
try
60+
{
61+
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));
62+
63+
Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");
64+
65+
Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };
66+
67+
ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
68+
69+
Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");
70+
71+
ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
72+
73+
object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });
74+
75+
MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);
76+
77+
byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
78+
79+
object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
80+
FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
81+
accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);
82+
83+
string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });
84+
85+
return tokenHash;
86+
}
87+
catch (Exception)
88+
{
89+
return "";
90+
}
91+
}
92+
93+
internal static object GetAuthenticationContextValue(SqlConnection connection)
94+
{
95+
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);
96+
97+
object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);
98+
99+
IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);
100+
101+
object authenticationContextObj = authenticationContexts.Cast<object>().FirstOrDefault();
102+
103+
object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null);
104+
105+
return authenticationContextValueObj;
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)