|
5 | 5 | using System; |
6 | 6 | using System.Net; |
7 | 7 | using System.Net.Sockets; |
| 8 | +using System.Reflection; |
8 | 9 | using System.Text; |
9 | 10 | using System.Threading.Tasks; |
10 | 11 | using Xunit; |
@@ -83,6 +84,138 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove |
83 | 84 | } |
84 | 85 | } |
85 | 86 |
|
| 87 | + // Note: This Unit test was tested in a domain-joined VM connecting to a remote |
| 88 | + // SQL Server using Kerberos in the same domain. |
| 89 | + [ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false |
| 90 | + [ConditionalFact(nameof(IsKerberos))] |
| 91 | + public static void PortNumberInSPNTest() |
| 92 | + { |
| 93 | + string connStr = DataTestUtility.TCPConnectionString; |
| 94 | + // If config.json.SupportsIntegratedSecurity = true, replace all keys defined below with Integrated Security=true |
| 95 | + if (DataTestUtility.IsIntegratedSecuritySetup()) |
| 96 | + { |
| 97 | + string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD", "Trusted_Connection" }; |
| 98 | + connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.TCPConnectionString, removeKeys) + $"Integrated Security=true"; |
| 99 | + } |
| 100 | + |
| 101 | + SqlConnectionStringBuilder builder = new(connStr); |
| 102 | + |
| 103 | + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName), "Data source to be parsed must contain a host name and instance name"); |
| 104 | + |
| 105 | + bool condition = IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName); |
| 106 | + Assert.True(condition, "Browser service is not running or instance name is invalid"); |
| 107 | + |
| 108 | + if (condition) |
| 109 | + { |
| 110 | + using SqlConnection connection = new(builder.ConnectionString); |
| 111 | + connection.Open(); |
| 112 | + using SqlCommand command = new("SELECT auth_scheme, local_tcp_port from sys.dm_exec_connections where session_id = @@spid", connection); |
| 113 | + using SqlDataReader reader = command.ExecuteReader(); |
| 114 | + Assert.True(reader.Read(), "Expected to receive one row data"); |
| 115 | + Assert.Equal("KERBEROS", reader.GetString(0)); |
| 116 | + int localTcpPort = reader.GetInt32(1); |
| 117 | + |
| 118 | + int spnPort = -1; |
| 119 | + string spnInfo = GetSPNInfo(builder.DataSource, out spnPort); |
| 120 | + |
| 121 | + // sample output to validate = MSSQLSvc/machine.domain.tld:spnPort" |
| 122 | + Assert.Contains($"MSSQLSvc/{hostname}", spnInfo); |
| 123 | + // the local_tcp_port should be the same as the inferred SPN port from instance name |
| 124 | + Assert.Equal(localTcpPort, spnPort); |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + private static string GetSPNInfo(string datasource, out int out_port) |
| 129 | + { |
| 130 | + Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection)); |
| 131 | + |
| 132 | + // Get all required types using reflection |
| 133 | + Type sniProxyType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SNIProxy"); |
| 134 | + Type ssrpType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.SSRP"); |
| 135 | + Type dataSourceType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SNI.DataSource"); |
| 136 | + Type timeoutTimerType = sqlConnectionAssembly.GetType("Microsoft.Data.ProviderBase.TimeoutTimer"); |
| 137 | + |
| 138 | + // Used in Datasource constructor param type array |
| 139 | + Type[] dataSourceConstructorTypesArray = new Type[] { typeof(string) }; |
| 140 | + |
| 141 | + // Used in GetSqlServerSPNs function param types array |
| 142 | + Type[] getSqlServerSPNsTypesArray = new Type[] { dataSourceType, typeof(string) }; |
| 143 | + |
| 144 | + // GetPortByInstanceName parameters array |
| 145 | + Type[] getPortByInstanceNameTypesArray = new Type[] { typeof(string), typeof(string), timeoutTimerType, typeof(bool), typeof(Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference) }; |
| 146 | + |
| 147 | + // TimeoutTimer.StartSecondsTimeout params |
| 148 | + Type[] startSecondsTimeoutTypesArray = new Type[] { typeof(int) }; |
| 149 | + |
| 150 | + // Get all types constructors |
| 151 | + ConstructorInfo sniProxyCtor = sniProxyType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); |
| 152 | + ConstructorInfo SSRPCtor = ssrpType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); |
| 153 | + ConstructorInfo dataSourceCtor = dataSourceType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); |
| 154 | + ConstructorInfo timeoutTimerCtor = timeoutTimerType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null); |
| 155 | + |
| 156 | + // Instantiate SNIProxy |
| 157 | + object sniProxy = sniProxyCtor.Invoke(new object[] { }); |
| 158 | + |
| 159 | + // Instantiate datasource |
| 160 | + object dataSourceObj = dataSourceCtor.Invoke(new object[] { datasource }); |
| 161 | + |
| 162 | + // Instantiate SSRP |
| 163 | + object ssrp = SSRPCtor.Invoke(new object[] { }); |
| 164 | + |
| 165 | + // Instantiate TimeoutTimer |
| 166 | + object timeoutTimer = timeoutTimerCtor.Invoke(new object[] { }); |
| 167 | + |
| 168 | + // Get TimeoutTimer.StartSecondsTimeout Method |
| 169 | + MethodInfo startSecondsTimeout = timeoutTimer.GetType().GetMethod("StartSecondsTimeout", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, startSecondsTimeoutTypesArray, null); |
| 170 | + // Create a timeoutTimer that expires in 30 seconds |
| 171 | + timeoutTimer = startSecondsTimeout.Invoke(dataSourceObj, new object[] { 30 }); |
| 172 | + |
| 173 | + // Parse the datasource to separate the server name and instance name |
| 174 | + MethodInfo ParseServerName = dataSourceObj.GetType().GetMethod("ParseServerName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, dataSourceConstructorTypesArray, null); |
| 175 | + object dataSrcInfo = ParseServerName.Invoke(dataSourceObj, new object[] { datasource }); |
| 176 | + |
| 177 | + // Get the GetPortByInstanceName method of SSRP |
| 178 | + MethodInfo getPortByInstanceName = ssrp.GetType().GetMethod("GetPortByInstanceName", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getPortByInstanceNameTypesArray, null); |
| 179 | + |
| 180 | + // Get the server name |
| 181 | + PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); |
| 182 | + string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString(); |
| 183 | + |
| 184 | + // Get the instance name |
| 185 | + PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); |
| 186 | + string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString(); |
| 187 | + |
| 188 | + // Get the port number using the GetPortByInstanceName method of SSRP |
| 189 | + object port = getPortByInstanceName.Invoke(ssrp, parameters: new object[] { serverName, instanceName, timeoutTimer, false, 0 }); |
| 190 | + |
| 191 | + // Set the resolved port property of datasource |
| 192 | + PropertyInfo resolvedPortInfo = dataSrcInfo.GetType().GetProperty("ResolvedPort", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); |
| 193 | + resolvedPortInfo.SetValue(dataSrcInfo, (int)port, null); |
| 194 | + |
| 195 | + // Prepare the GetSqlServerSPNs method |
| 196 | + string serverSPN = ""; |
| 197 | + MethodInfo getSqlServerSPNs = sniProxy.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); |
| 198 | + |
| 199 | + // Finally call GetSqlServerSPNs |
| 200 | + byte[][] result = (byte[][])getSqlServerSPNs.Invoke(sniProxy, new object[] { dataSrcInfo, serverSPN }); |
| 201 | + |
| 202 | + // Example result: MSSQLSvc/machine.domain.tld:port" |
| 203 | + string spnInfo = Encoding.Unicode.GetString(result[0]); |
| 204 | + |
| 205 | + out_port = (int)port; |
| 206 | + |
| 207 | + return spnInfo; |
| 208 | + } |
| 209 | + |
| 210 | + private static bool IsKerberos() |
| 211 | + { |
| 212 | + return (DataTestUtility.AreConnStringsSetup() |
| 213 | + && DataTestUtility.IsNotLocalhost() |
| 214 | + && DataTestUtility.IsKerberosTest |
| 215 | + && DataTestUtility.IsNotAzureServer() |
| 216 | + && DataTestUtility.IsNotAzureSynapse()); |
| 217 | + } |
| 218 | + |
86 | 219 | private static bool IsBrowserAlive(string browserHostname) |
87 | 220 | { |
88 | 221 | const byte ClntUcastEx = 0x03; |
|
0 commit comments