33// See the LICENSE file in the project root for more information.
44
55using System ;
6+ using System . Collections . Generic ;
67using System . Diagnostics ;
8+ using System . Linq ;
79using System . Net ;
810using System . Net . Sockets ;
911using System . Text ;
12+ using System . Threading ;
1013using System . Threading . Tasks ;
1114
1215namespace Microsoft . Data . SqlClient . SNI
@@ -21,8 +24,11 @@ internal class SSRP
2124 /// </summary>
2225 /// <param name="browserHostName">SQL Sever Browser hostname</param>
2326 /// <param name="instanceName">instance name to find port number</param>
27+ /// <param name="timerExpire">Connection timer expiration</param>
28+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
29+ /// <param name="ipPreference">IP address preference</param>
2430 /// <returns>port number for given instance name</returns>
25- internal static int GetPortByInstanceName ( string browserHostName , string instanceName )
31+ internal static int GetPortByInstanceName ( string browserHostName , string instanceName , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
2632 {
2733 Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostName ) , "browserHostName should not be null, empty, or whitespace" ) ;
2834 Debug . Assert ( ! string . IsNullOrWhiteSpace ( instanceName ) , "instanceName should not be null, empty, or whitespace" ) ;
@@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
3238 byte [ ] responsePacket = null ;
3339 try
3440 {
35- responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , instanceInfoRequest ) ;
41+ responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , instanceInfoRequest , timerExpire , allIPsInParallel , ipPreference ) ;
3642 }
3743 catch ( SocketException se )
3844 {
@@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
8793 /// </summary>
8894 /// <param name="browserHostName">SQL Sever Browser hostname</param>
8995 /// <param name="instanceName">instance name to lookup DAC port</param>
96+ /// <param name="timerExpire">Connection timer expiration</param>
97+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
98+ /// <param name="ipPreference">IP address preference</param>
9099 /// <returns>DAC port for given instance name</returns>
91- internal static int GetDacPortByInstanceName ( string browserHostName , string instanceName )
100+ internal static int GetDacPortByInstanceName ( string browserHostName , string instanceName , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
92101 {
93102 Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostName ) , "browserHostName should not be null, empty, or whitespace" ) ;
94103 Debug . Assert ( ! string . IsNullOrWhiteSpace ( instanceName ) , "instanceName should not be null, empty, or whitespace" ) ;
95104
96105 byte [ ] dacPortInfoRequest = CreateDacPortInfoRequest ( instanceName ) ;
97- byte [ ] responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , dacPortInfoRequest ) ;
106+ byte [ ] responsePacket = SendUDPRequest ( browserHostName , SqlServerBrowserPort , dacPortInfoRequest , timerExpire , allIPsInParallel , ipPreference ) ;
98107
99108 const byte SvrResp = 0x05 ;
100109 const byte ProtocolVersion = 0x01 ;
@@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
131140 return requestPacket ;
132141 }
133142
143+ private class SsrpResult
144+ {
145+ public byte [ ] ResponsePacket ;
146+ public Exception Error ;
147+ }
148+
134149 /// <summary>
135150 /// Sends request to server, and receives response from server by UDP.
136151 /// </summary>
137152 /// <param name="browserHostname">UDP server hostname</param>
138153 /// <param name="port">UDP server port</param>
139154 /// <param name="requestPacket">request packet</param>
155+ /// <param name="timerExpire">Connection timer expiration</param>
156+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
157+ /// <param name="ipPreference">IP address preference</param>
140158 /// <returns>response packet from UDP server</returns>
141- private static byte [ ] SendUDPRequest ( string browserHostname , int port , byte [ ] requestPacket )
159+ private static byte [ ] SendUDPRequest ( string browserHostname , int port , byte [ ] requestPacket , long timerExpire , bool allIPsInParallel , SqlConnectionIPAddressPreference ipPreference )
142160 {
143161 using ( TrySNIEventScope . Create ( nameof ( SSRP ) ) )
144162 {
145163 Debug . Assert ( ! string . IsNullOrWhiteSpace ( browserHostname ) , "browserhostname should not be null, empty, or whitespace" ) ;
146164 Debug . Assert ( port >= 0 && port <= 65535 , "Invalid port" ) ;
147165 Debug . Assert ( requestPacket != null && requestPacket . Length > 0 , "requestPacket should not be null or 0-length array" ) ;
148166
149- const int sendTimeOutMs = 1000 ;
150- const int receiveTimeOutMs = 1000 ;
167+ bool isIpAddress = IPAddress . TryParse ( browserHostname , out IPAddress address ) ;
151168
152- IPAddress address = null ;
153- bool isIpAddress = IPAddress . TryParse ( browserHostname , out address ) ;
169+ TimeSpan ts = default ;
170+ // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
171+ // The infinite Timeout is a function of ConnectionString Timeout=0
172+ if ( long . MaxValue != timerExpire )
173+ {
174+ ts = DateTime . FromFileTime ( timerExpire ) - DateTime . Now ;
175+ ts = ts . Ticks < 0 ? TimeSpan . FromTicks ( 0 ) : ts ;
176+ }
154177
155- byte [ ] responsePacket = null ;
156- using ( UdpClient client = new UdpClient ( ! isIpAddress ? AddressFamily . InterNetwork : address . AddressFamily ) )
178+ IPAddress [ ] ipAddresses = null ;
179+ if ( ! isIpAddress )
180+ {
181+ Task < IPAddress [ ] > serverAddrTask = Dns . GetHostAddressesAsync ( browserHostname ) ;
182+ bool taskComplete ;
183+ try
184+ {
185+ taskComplete = serverAddrTask . Wait ( ts ) ;
186+ }
187+ catch ( AggregateException ae )
188+ {
189+ throw ae . InnerException ;
190+ }
191+
192+ // If DNS took too long, need to return instead of blocking
193+ if ( ! taskComplete )
194+ return null ;
195+
196+ ipAddresses = serverAddrTask . Result ;
197+ }
198+
199+ Debug . Assert ( ipAddresses . Length > 0 , "DNS should throw if zero addresses resolve" ) ;
200+
201+ switch ( ipPreference )
157202 {
158- Task < int > sendTask = client . SendAsync ( requestPacket , requestPacket . Length , browserHostname , port ) ;
203+ case SqlConnectionIPAddressPreference . IPv4First :
204+ {
205+ SsrpResult response4 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetwork ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
206+ if ( response4 != null && response4 . ResponsePacket != null )
207+ return response4 . ResponsePacket ;
208+
209+ SsrpResult response6 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetworkV6 ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
210+ if ( response6 != null && response6 . ResponsePacket != null )
211+ return response6 . ResponsePacket ;
212+
213+ // No responses so throw first error
214+ if ( response4 != null && response4 . Error != null )
215+ throw response4 . Error ;
216+ else if ( response6 != null && response6 . Error != null )
217+ throw response6 . Error ;
218+
219+ break ;
220+ }
221+ case SqlConnectionIPAddressPreference . IPv6First :
222+ {
223+ SsrpResult response6 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetworkV6 ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
224+ if ( response6 != null && response6 . ResponsePacket != null )
225+ return response6 . ResponsePacket ;
226+
227+ SsrpResult response4 = SendUDPRequest ( ipAddresses . Where ( i => i . AddressFamily == AddressFamily . InterNetwork ) . ToArray ( ) , port , requestPacket , allIPsInParallel ) ;
228+ if ( response4 != null && response4 . ResponsePacket != null )
229+ return response4 . ResponsePacket ;
230+
231+ // No responses so throw first error
232+ if ( response6 != null && response6 . Error != null )
233+ throw response6 . Error ;
234+ else if ( response4 != null && response4 . Error != null )
235+ throw response4 . Error ;
236+
237+ break ;
238+ }
239+ default :
240+ {
241+ SsrpResult response = SendUDPRequest ( ipAddresses , port , requestPacket , true ) ; // allIPsInParallel);
242+ if ( response != null && response . ResponsePacket != null )
243+ return response . ResponsePacket ;
244+ else if ( response != null && response . Error != null )
245+ throw response . Error ;
246+
247+ break ;
248+ }
249+ }
250+
251+ return null ;
252+ }
253+ }
254+
255+ /// <summary>
256+ /// Sends request to server, and receives response from server by UDP.
257+ /// </summary>
258+ /// <param name="ipAddresses">IP Addresses</param>
259+ /// <param name="port">UDP server port</param>
260+ /// <param name="requestPacket">request packet</param>
261+ /// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
262+ /// <returns>response packet from UDP server</returns>
263+ private static SsrpResult SendUDPRequest ( IPAddress [ ] ipAddresses , int port , byte [ ] requestPacket , bool allIPsInParallel )
264+ {
265+ if ( ipAddresses . Length == 0 )
266+ return null ;
267+
268+ if ( allIPsInParallel ) // Used for MultiSubnetFailover
269+ {
270+ List < Task < SsrpResult > > tasks = new ( ipAddresses . Length ) ;
271+ CancellationTokenSource cts = new CancellationTokenSource ( ) ;
272+ for ( int i = 0 ; i < ipAddresses . Length ; i ++ )
273+ {
274+ IPEndPoint endPoint = new IPEndPoint ( ipAddresses [ i ] , port ) ;
275+ tasks . Add ( Task . Factory . StartNew < SsrpResult > ( ( ) => SendUDPRequest ( endPoint , requestPacket ) ) ) ;
276+ }
277+
278+ List < Task < SsrpResult > > completedTasks = new ( ) ;
279+ while ( tasks . Count > 0 )
280+ {
281+ int first = Task . WaitAny ( tasks . ToArray ( ) ) ;
282+ if ( tasks [ first ] . Result . ResponsePacket != null )
283+ {
284+ cts . Cancel ( ) ;
285+ return tasks [ first ] . Result ;
286+ }
287+ else
288+ {
289+ completedTasks . Add ( tasks [ first ] ) ;
290+ tasks . Remove ( tasks [ first ] ) ;
291+ }
292+ }
293+
294+ Debug . Assert ( completedTasks . Count > 0 , "completedTasks should never be 0" ) ;
295+
296+ // All tasks failed. Return the error from the first failure.
297+ return completedTasks [ 0 ] . Result ;
298+ }
299+ else
300+ {
301+ // If not parallel, use the first IP address provided
302+ IPEndPoint endPoint = new IPEndPoint ( ipAddresses [ 0 ] , port ) ;
303+ return SendUDPRequest ( endPoint , requestPacket ) ;
304+ }
305+ }
306+
307+ private static SsrpResult SendUDPRequest ( IPEndPoint endPoint , byte [ ] requestPacket )
308+ {
309+ const int sendTimeOutMs = 1000 ;
310+ const int receiveTimeOutMs = 1000 ;
311+
312+ SsrpResult result = new ( ) ;
313+
314+ try
315+ {
316+ using ( UdpClient client = new UdpClient ( endPoint . AddressFamily ) )
317+ {
318+ Task < int > sendTask = client . SendAsync ( requestPacket , requestPacket . Length , endPoint ) ;
159319 Task < UdpReceiveResult > receiveTask = null ;
160-
320+
161321 SqlClientEventSource . Log . TrySNITraceEvent ( nameof ( SSRP ) , EventType . INFO , "Waiting for UDP Client to fetch Port info." ) ;
162322 if ( sendTask . Wait ( sendTimeOutMs ) && ( receiveTask = client . ReceiveAsync ( ) ) . Wait ( receiveTimeOutMs ) )
163323 {
164324 SqlClientEventSource . Log . TrySNITraceEvent ( nameof ( SSRP ) , EventType . INFO , "Received Port info from UDP Client." ) ;
165- responsePacket = receiveTask . Result . Buffer ;
325+ result . ResponsePacket = receiveTask . Result . Buffer ;
166326 }
167327 }
168-
169- return responsePacket ;
170328 }
329+ catch ( Exception e )
330+ {
331+ result . Error = e ;
332+ }
333+
334+ return result ;
171335 }
172336 }
173337}
0 commit comments