Skip to content

Commit f478be5

Browse files
authored
SqlClient-826 Missed synchronization (#1029)
1 parent 4cd43be commit f478be5

File tree

5 files changed

+154
-116
lines changed

5 files changed

+154
-116
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Common/src/System/Net/InternalException.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@
44

55
namespace System.Net
66
{
7+
[Serializable]
78
internal class InternalException : Exception
89
{
9-
internal InternalException()
10+
public InternalException() : this("InternalException thrown.")
1011
{
11-
NetEventSource.Fail(this, "InternalException thrown.");
12+
}
13+
14+
public InternalException(string message) : this(message, null)
15+
{
16+
}
17+
18+
public InternalException(string message, Exception innerException) : base(message, innerException)
19+
{
20+
NetEventSource.Fail(this, message);
1221
}
1322
}
1423
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs

Lines changed: 125 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.ComponentModel;
8+
using System.Diagnostics;
89
using System.IO;
910
using System.Net;
1011
using System.Net.Security;
@@ -14,6 +15,7 @@
1415
using System.Security.Cryptography.X509Certificates;
1516
using System.Threading;
1617
using System.Threading.Tasks;
18+
using Microsoft.Data.Common;
1719

1820
namespace Microsoft.Data.SqlClient.SNI
1921
{
@@ -347,149 +349,158 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
347349
availableSocket = connectTask.Result;
348350
return availableSocket;
349351
}
352+
353+
/// <summary>
354+
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
355+
/// </summary>
356+
/// <exception cref="ArgumentOutOfRangeException">Thrown when ipPreference is not supported</exception>
357+
private static IEnumerable<IPAddress> GetHostAddressesSortedByPreference(string serverName, SqlConnectionIPAddressPreference ipPreference)
358+
{
359+
IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
360+
AddressFamily? prioritiesFamily = ipPreference switch
361+
{
362+
SqlConnectionIPAddressPreference.IPv4First => AddressFamily.InterNetwork,
363+
SqlConnectionIPAddressPreference.IPv6First => AddressFamily.InterNetworkV6,
364+
SqlConnectionIPAddressPreference.UsePlatformDefault => null,
365+
_ => throw ADP.NotSupportedEnumerationValue(typeof(SqlConnectionIPAddressPreference), ipPreference.ToString(), nameof(GetHostAddressesSortedByPreference))
366+
};
367+
368+
// Return addresses of the preferred family first
369+
if (prioritiesFamily != null)
370+
{
371+
foreach (IPAddress ipAddress in ipAddresses)
372+
{
373+
if (ipAddress.AddressFamily == prioritiesFamily)
374+
{
375+
yield return ipAddress;
376+
}
377+
}
378+
}
350379

380+
// Return addresses of the other family
381+
foreach (IPAddress ipAddress in ipAddresses)
382+
{
383+
if (ipAddress.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)
384+
{
385+
if (prioritiesFamily == null || ipAddress.AddressFamily != prioritiesFamily)
386+
{
387+
yield return ipAddress;
388+
}
389+
}
390+
}
391+
}
392+
351393
// Connect to server with hostName and port.
352394
// The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point.
353395
// Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server.
354396
private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
355397
{
356398
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));
357399

358-
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName);
400+
Stopwatch timeTaken = Stopwatch.StartNew();
359401

360-
string IPv4String = null;
361-
string IPv6String = null;
402+
IEnumerable<IPAddress> ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference);
362403

363-
// Returning null socket is handled by the caller function.
364-
if (ipAddresses == null || ipAddresses.Length == 0)
404+
foreach (IPAddress ipAddress in ipAddresses)
365405
{
366-
return null;
367-
}
368-
369-
Socket[] sockets = new Socket[ipAddresses.Length];
370-
AddressFamily[] preferedIPFamilies = new AddressFamily[2];
406+
bool isSocketSelected = false;
407+
Socket socket = null;
371408

372-
if (ipPreference == SqlConnectionIPAddressPreference.IPv4First)
373-
{
374-
preferedIPFamilies[0] = AddressFamily.InterNetwork;
375-
preferedIPFamilies[1] = AddressFamily.InterNetworkV6;
376-
}
377-
else if (ipPreference == SqlConnectionIPAddressPreference.IPv6First)
378-
{
379-
preferedIPFamilies[0] = AddressFamily.InterNetworkV6;
380-
preferedIPFamilies[1] = AddressFamily.InterNetwork;
381-
}
382-
// else -> UsePlatformDefault
383-
384-
CancellationTokenSource cts = null;
385-
386-
if (!isInfiniteTimeout)
387-
{
388-
cts = new CancellationTokenSource(timeout);
389-
cts.Token.Register(Cancel);
390-
}
391-
392-
Socket availableSocket = null;
393-
try
394-
{
395-
// We go through the IP list twice.
396-
// In the first traversal, we only try to connect with the preferedIPFamilies[0].
397-
// In the second traversal, we only try to connect with the preferedIPFamilies[1].
398-
// For UsePlatformDefault preference, we do traversal once.
399-
for (int i = 0; i < preferedIPFamilies.Length; ++i)
409+
try
400410
{
401-
for (int n = 0; n < ipAddresses.Length; n++)
411+
socket = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp)
412+
{
413+
Blocking = isInfiniteTimeout
414+
};
415+
416+
// enable keep-alive on socket
417+
SetKeepAliveValues(ref socket);
418+
419+
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO,
420+
"Connecting to IP address {0} and port {1} using {2} address family. Is infinite timeout: {3}",
421+
ipAddress,
422+
port,
423+
ipAddress.AddressFamily,
424+
isInfiniteTimeout);
425+
426+
bool isConnected;
427+
try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select
402428
{
403-
IPAddress ipAddress = ipAddresses[n];
404-
try
429+
socket.Connect(ipAddress, port);
430+
if (!isInfiniteTimeout)
405431
{
406-
if (ipAddress != null)
407-
{
408-
if (ipAddress.AddressFamily != preferedIPFamilies[i] && ipPreference != SqlConnectionIPAddressPreference.UsePlatformDefault)
409-
{
410-
continue;
411-
}
432+
throw SQL.SocketDidNotThrow();
433+
}
434+
435+
isConnected = true;
436+
}
437+
catch (SocketException socketException) when (!isInfiniteTimeout &&
438+
socketException.SocketErrorCode ==
439+
SocketError.WouldBlock)
440+
{
441+
// https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118
442+
// Socket.Select is used because it supports timeouts, while Socket.Connect does not
412443

413-
sockets[n] = new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
444+
List<Socket> checkReadLst; List<Socket> checkWriteLst; List<Socket> checkErrorLst;
414445

415-
// enable keep-alive on socket
416-
SetKeepAliveValues(ref sockets[n]);
446+
// Repeating Socket.Select several times if our timeout is greater
447+
// than int.MaxValue microseconds because of
448+
// https://github.com/dotnet/SqlClient/pull/1029#issuecomment-875364044
449+
// which states that Socket.Select can't handle timeouts greater than int.MaxValue microseconds
450+
do
451+
{
452+
TimeSpan timeLeft = timeout - timeTaken.Elapsed;
417453

418-
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connecting to IP address {0} and port {1} using {2} address family.",
419-
args0: ipAddress,
420-
args1: port,
421-
args2: ipAddress.AddressFamily);
422-
sockets[n].Connect(ipAddress, port);
423-
if (sockets[n] != null) // sockets[n] can be null if cancel callback is executed during connect()
424-
{
425-
if (sockets[n].Connected)
426-
{
427-
availableSocket = sockets[n];
428-
if (ipAddress.AddressFamily == AddressFamily.InterNetwork)
429-
{
430-
IPv4String = ipAddress.ToString();
431-
}
432-
else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
433-
{
434-
IPv6String = ipAddress.ToString();
435-
}
454+
if (timeLeft <= TimeSpan.Zero)
455+
return null;
436456

437-
break;
438-
}
439-
else
440-
{
441-
sockets[n].Dispose();
442-
sockets[n] = null;
443-
}
444-
}
445-
}
446-
}
447-
catch (Exception e)
448-
{
449-
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
450-
SqlClientEventSource.Log.TryAdvancedTraceEvent($"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
451-
}
452-
}
457+
int socketSelectTimeout =
458+
checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000));
453459

454-
// If we have already got a valid Socket, or the platform default was prefered
455-
// we won't do the second traversal.
456-
if (availableSocket is not null || ipPreference == SqlConnectionIPAddressPreference.UsePlatformDefault)
457-
{
458-
break;
459-
}
460-
}
461-
}
462-
finally
463-
{
464-
cts?.Dispose();
465-
}
460+
checkReadLst = new List<Socket>(1) { socket };
461+
checkWriteLst = new List<Socket>(1) { socket };
462+
checkErrorLst = new List<Socket>(1) { socket };
466463

467-
// we only record the ip we can connect with successfully.
468-
if (IPv4String != null || IPv6String != null)
469-
{
470-
pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString());
471-
}
464+
Socket.Select(checkReadLst, checkWriteLst, checkErrorLst, socketSelectTimeout);
465+
// nothing selected means timeout
466+
} while (checkReadLst.Count == 0 && checkWriteLst.Count == 0 && checkErrorLst.Count == 0);
472467

473-
return availableSocket;
468+
// workaround: false positive socket.Connected on linux: https://github.com/dotnet/runtime/issues/55538
469+
isConnected = socket.Connected && checkErrorLst.Count == 0;
470+
}
474471

475-
void Cancel()
476-
{
477-
for (int i = 0; i < sockets.Length; ++i)
478-
{
479-
try
472+
if (isConnected)
480473
{
481-
if (sockets[i] != null && !sockets[i].Connected)
474+
socket.Blocking = true;
475+
string iPv4String = null;
476+
string iPv6String = null;
477+
if (socket.AddressFamily == AddressFamily.InterNetwork)
482478
{
483-
sockets[i].Dispose();
484-
sockets[i] = null;
479+
iPv4String = ipAddress.ToString();
485480
}
481+
else
482+
{
483+
iPv6String = ipAddress.ToString();
484+
}
485+
pendingDNSInfo = new SQLDNSInfo(cachedFQDN, iPv4String, iPv6String, port.ToString());
486+
isSocketSelected = true;
487+
return socket;
486488
}
487-
catch (Exception e)
488-
{
489-
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
490-
}
489+
}
490+
catch (SocketException e)
491+
{
492+
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
493+
SqlClientEventSource.Log.TryAdvancedTraceEvent(
494+
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
495+
}
496+
finally
497+
{
498+
if (!isSocketSelected)
499+
socket?.Dispose();
491500
}
492501
}
502+
503+
return null;
493504
}
494505

495506
private static Task<Socket> ParallelConnectAsync(IPAddress[] serverAddresses, int port)

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
using System.Data;
99
using System.Diagnostics;
1010
using System.Globalization;
11+
using System.Net;
12+
using System.Net.Sockets;
1113
using System.Reflection;
1214
using System.Runtime.CompilerServices;
1315
using System.Runtime.ExceptionServices;
@@ -377,6 +379,10 @@ internal static Exception SynchronousCallMayNotPend()
377379
{
378380
return new Exception(StringsHelper.GetString(Strings.Sql_InternalError));
379381
}
382+
internal static Exception SocketDidNotThrow()
383+
{
384+
return new InternalException(StringsHelper.GetString(Strings.SQL_SocketDidNotThrow, nameof(SocketException), nameof(SocketError.WouldBlock)));
385+
}
380386
internal static Exception ConnectionLockedForBcpEvent()
381387
{
382388
return ADP.InvalidOperation(StringsHelper.GetString(Strings.SQL_ConnectionLockedForBcpEvent));

src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.Designer.cs

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/Microsoft.Data.SqlClient/netcore/src/Resources/Strings.resx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,4 +1941,7 @@
19411941
<data name="SQL_TDS8_NotSupported_Netstandard2.0" xml:space="preserve">
19421942
<value>Encrypt=Strict is not supported when targeting .NET Standard 2.0. Use .NET Standard 2.1, .NET Framework, or .NET.</value>
19431943
</data>
1944+
<data name="SQL_SocketDidNotThrow" xml:space="preserve">
1945+
<value>Socket did not throw expected '{0}' with error code '{1}'.</value>
1946+
</data>
19441947
</root>

0 commit comments

Comments
 (0)