|
5 | 5 | using System; |
6 | 6 | using System.Collections.Generic; |
7 | 7 | using System.ComponentModel; |
| 8 | +using System.Diagnostics; |
8 | 9 | using System.IO; |
9 | 10 | using System.Net; |
10 | 11 | using System.Net.Security; |
|
14 | 15 | using System.Security.Cryptography.X509Certificates; |
15 | 16 | using System.Threading; |
16 | 17 | using System.Threading.Tasks; |
| 18 | +using Microsoft.Data.Common; |
17 | 19 |
|
18 | 20 | namespace Microsoft.Data.SqlClient.SNI |
19 | 21 | { |
@@ -347,149 +349,158 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i |
347 | 349 | availableSocket = connectTask.Result; |
348 | 350 | return availableSocket; |
349 | 351 | } |
| 352 | + |
| 353 | + /// <summary> |
| 354 | + /// Returns array of IP addresses for the given server name, sorted according to the given preference. |
| 355 | + /// </summary> |
| 356 | + /// <exception cref="ArgumentOutOfRangeException">Thrown when ipPreference is not supported</exception> |
| 357 | + private static IEnumerable<IPAddress> GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference) |
| 358 | + { |
| 359 | + IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); |
| 360 | + AddressFamily? prioritiesFamily = ipPreference switch |
| 361 | + { |
| 362 | + SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork, |
| 363 | + SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6, |
| 364 | + SqlConnectionIPAddressPreference.UsePlatformDefault => null, |
| 365 | + _ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference)) |
| 366 | + }; |
| 367 | + |
| 368 | + // Return addresses of the preferred family first |
| 369 | + if (prioritiesFamily != null) |
| 370 | + { |
| 371 | + foreach (IPAddress ipAddress in ipAddresses) |
| 372 | + { |
| 373 | + if (ipAddress.AddressFamily == prioritiesFamily) |
| 374 | + { |
| 375 | + yield return ipAddress; |
| 376 | + } |
| 377 | + } |
| 378 | + } |
350 | 379 |
|
| 380 | + // Return addresses of the other family |
| 381 | + foreach (IPAddress ipAddress in ipAddresses) |
| 382 | + { |
| 383 | + if (ipAddress.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6) |
| 384 | + { |
| 385 | + if (prioritiesFamily == null || ipAddress.AddressFamily != prioritiesFamily) |
| 386 | + { |
| 387 | + yield return ipAddress; |
| 388 | + } |
| 389 | + } |
| 390 | + } |
| 391 | + } |
| 392 | + |
351 | 393 | // Connect to server with hostName and port. |
352 | 394 | // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. |
353 | 395 | // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. |
354 | 396 | private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) |
355 | 397 | { |
356 | 398 | SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); |
357 | 399 |
|
358 | | - IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName); |
| 400 | + Stopwatch timeTaken = Stopwatch.StartNew(); |
359 | 401 |
|
360 | | - string IPv4String = null; |
361 | | - string IPv6String = null; |
| 402 | + IEnumerable<IPAddress> ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); |
362 | 403 |
|
363 | | - // Returning null socket is handled by the caller function. |
364 | | - if (ipAddresses == null || ipAddresses.Length == 0) |
| 404 | + foreach (IPAddress ipAddress in ipAddresses) |
365 | 405 | { |
366 | | - return null; |
367 | | - } |
368 | | - |
369 | | - Socket[] sockets = new Socket[ipAddresses.Length]; |
370 | | - AddressFamily[] preferedIPFamilies = new AddressFamily[2]; |
| 406 | + bool isSocketSelected = false; |
| 407 | + Socket socket = null; |
371 | 408 |
|
372 | | - if (ipPreference == SqlConnectionIPAddressPreference.IPv4First) |
373 | | - { |
374 | | - preferedIPFamilies[0] = AddressFamily.InterNetwork; |
375 | | - preferedIPFamilies[1] = AddressFamily.InterNetworkV6; |
376 | | - } |
377 | | - else if (ipPreference == SqlConnectionIPAddressPreference.IPv6First) |
378 | | - { |
379 | | - preferedIPFamilies[0] = AddressFamily.InterNetworkV6; |
380 | | - preferedIPFamilies[1] = AddressFamily.InterNetwork; |
381 | | - } |
382 | | - // else -> UsePlatformDefault |
383 | | - |
384 | | - CancellationTokenSource cts = null; |
385 | | - |
386 | | - if (!isInfiniteTimeout) |
387 | | - { |
388 | | - cts = new CancellationTokenSource(timeout); |
389 | | - cts.Token.Register(Cancel); |
390 | | - } |
391 | | - |
392 | | - Socket availableSocket = null; |
393 | | - try |
394 | | - { |
395 | | - // We go through the IP list twice. |
396 | | - // In the first traversal, we only try to connect with the preferedIPFamilies[0]. |
397 | | - // In the second traversal, we only try to connect with the preferedIPFamilies[1]. |
398 | | - // For UsePlatformDefault preference, we do traversal once. |
399 | | - for (int i = 0; i < preferedIPFamilies.Length; ++i) |
| 409 | + try |
400 | 410 | { |
401 | | - for (int n = 0; n < ipAddresses.Length; n++) |
| 411 | + socket = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp) |
| 412 | + { |
| 413 | + Blocking = isInfiniteTimeout |
| 414 | + }; |
| 415 | + |
| 416 | + // enable keep-alive on socket |
| 417 | + SetKeepAliveValues(ref socket); |
| 418 | + |
| 419 | + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, |
| 420 | + "Connecting to IP address {0} and port {1} using {2} address family. Is infinite timeout: {3}", |
| 421 | + ipAddress, |
| 422 | + port, |
| 423 | + ipAddress.AddressFamily, |
| 424 | + isInfiniteTimeout); |
| 425 | + |
| 426 | + bool isConnected; |
| 427 | + try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select |
402 | 428 | { |
403 | | - IPAddress ipAddress = ipAddresses[n]; |
404 | | - try |
| 429 | + socket.Connect(ipAddress, port); |
| 430 | + if (!isInfiniteTimeout) |
405 | 431 | { |
406 | | - if (ipAddress != null) |
407 | | - { |
408 | | - if (ipAddress.AddressFamily != preferedIPFamilies[i] && ipPreference != SqlConnectionIPAddressPreference.UsePlatformDefault) |
409 | | - { |
410 | | - continue; |
411 | | - } |
| 432 | + throw SQL.SocketDidNotThrow(); |
| 433 | + } |
| 434 | + |
| 435 | + isConnected = true; |
| 436 | + } |
| 437 | + catch (SocketException socketException) when (!isInfiniteTimeout && |
| 438 | + socketException.SocketErrorCode == |
| 439 | + SocketError.WouldBlock) |
| 440 | + { |
| 441 | + // https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118 |
| 442 | + // Socket.Select is used because it supports timeouts, while Socket.Connect does not |
412 | 443 |
|
413 | | - sockets[n] = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp); |
| 444 | + List<Socket> checkReadLst; List<Socket> checkWriteLst; List<Socket> checkErrorLst; |
414 | 445 |
|
415 | | - // enable keep-alive on socket |
416 | | - SetKeepAliveValues(ref sockets[n]); |
| 446 | + // Repeating Socket.Select several times if our timeout is greater |
| 447 | + // than int.MaxValue microseconds because of |
| 448 | + // https://github.com/dotnet/SqlClient/pull/1029#issuecomment-875364044 |
| 449 | + // which states that Socket.Select can't handle timeouts greater than int.MaxValue microseconds |
| 450 | + do |
| 451 | + { |
| 452 | + TimeSpan timeLeft = timeout - timeTaken.Elapsed; |
417 | 453 |
|
418 | | - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connecting to IP address {0} and port {1} using {2} address family.", |
419 | | - args0: ipAddress, |
420 | | - args1: port, |
421 | | - args2: ipAddress.AddressFamily); |
422 | | - sockets[n].Connect(ipAddress, port); |
423 | | - if (sockets[n] != null) // sockets[n] can be null if cancel callback is executed during connect() |
424 | | - { |
425 | | - if (sockets[n].Connected) |
426 | | - { |
427 | | - availableSocket = sockets[n]; |
428 | | - if (ipAddress.AddressFamily == AddressFamily.InterNetwork) |
429 | | - { |
430 | | - IPv4String = ipAddress.ToString(); |
431 | | - } |
432 | | - else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) |
433 | | - { |
434 | | - IPv6String = ipAddress.ToString(); |
435 | | - } |
| 454 | + if (timeLeft <= TimeSpan.Zero) |
| 455 | + return null; |
436 | 456 |
|
437 | | - break; |
438 | | - } |
439 | | - else |
440 | | - { |
441 | | - sockets[n].Dispose(); |
442 | | - sockets[n] = null; |
443 | | - } |
444 | | - } |
445 | | - } |
446 | | - } |
447 | | - catch (Exception e) |
448 | | - { |
449 | | - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message); |
450 | | - SqlClientEventSource.Log.TryAdvancedTraceEvent($"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}"); |
451 | | - } |
452 | | - } |
| 457 | + int socketSelectTimeout = |
| 458 | + checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000)); |
453 | 459 |
|
454 | | - // If we have already got a valid Socket, or the platform default was prefered |
455 | | - // we won't do the second traversal. |
456 | | - if (availableSocket is not null || ipPreference == SqlConnectionIPAddressPreference.UsePlatformDefault) |
457 | | - { |
458 | | - break; |
459 | | - } |
460 | | - } |
461 | | - } |
462 | | - finally |
463 | | - { |
464 | | - cts?.Dispose(); |
465 | | - } |
| 460 | + checkReadLst = new List<Socket>(1) { socket }; |
| 461 | + checkWriteLst = new List<Socket>(1) { socket }; |
| 462 | + checkErrorLst = new List<Socket>(1) { socket }; |
466 | 463 |
|
467 | | - // we only record the ip we can connect with successfully. |
468 | | - if (IPv4String != null || IPv6String != null) |
469 | | - { |
470 | | - pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString()); |
471 | | - } |
| 464 | + Socket.Select(checkReadLst, checkWriteLst, checkErrorLst, socketSelectTimeout); |
| 465 | + // nothing selected means timeout |
| 466 | + } while (checkReadLst.Count == 0 && checkWriteLst.Count == 0 && checkErrorLst.Count == 0); |
472 | 467 |
|
473 | | - return availableSocket; |
| 468 | + // workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538 |
| 469 | + isConnected = socket.Connected && checkErrorLst.Count == 0; |
| 470 | + } |
474 | 471 |
|
475 | | - void Cancel() |
476 | | - { |
477 | | - for (int i = 0; i < sockets.Length; ++i) |
478 | | - { |
479 | | - try |
| 472 | + if (isConnected) |
480 | 473 | { |
481 | | - if (sockets[i] != null && !sockets[i].Connected) |
| 474 | + socket.Blocking = true; |
| 475 | + string iPv4String = null; |
| 476 | + string iPv6String = null; |
| 477 | + if (socket.AddressFamily == AddressFamily.InterNetwork) |
482 | 478 | { |
483 | | - sockets[i].Dispose(); |
484 | | - sockets[i] = null; |
| 479 | + iPv4String = ipAddress.ToString(); |
485 | 480 | } |
| 481 | + else |
| 482 | + { |
| 483 | + iPv6String = ipAddress.ToString(); |
| 484 | + } |
| 485 | + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, iPv4String, iPv6String, port.ToString()); |
| 486 | + isSocketSelected = true; |
| 487 | + return socket; |
486 | 488 | } |
487 | | - catch (Exception e) |
488 | | - { |
489 | | - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message); |
490 | | - } |
| 489 | + } |
| 490 | + catch (SocketException e) |
| 491 | + { |
| 492 | + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message); |
| 493 | + SqlClientEventSource.Log.TryAdvancedTraceEvent( |
| 494 | + $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}"); |
| 495 | + } |
| 496 | + finally |
| 497 | + { |
| 498 | + if (!isSocketSelected) |
| 499 | + socket?.Dispose(); |
491 | 500 | } |
492 | 501 | } |
| 502 | + |
| 503 | + return null; |
493 | 504 | } |
494 | 505 |
|
495 | 506 | private static Task<Socket> ParallelConnectAsync(IPAddress[] serverAddresses, int port) |
|
0 commit comments