Skip to content

Commit 10ccb20

Browse files
authored
Merge | TdsParserStateObject packet handling (#3603)
* netcore: teach TdsParserStateObject to clear the buffer after a password has been written to it * Port SetPacketData * Port references to SetPacketData, cleanup reference to SniPacketSetData overload * Port CreateAndSetAttentionPacket * Port CheckConnection * Port EnableSSL * Port Handle property and remaining variables * Remove remnants of CERs and tidy usings * Final cleanup - reference to SessionHandle, already-merged CER
1 parent 2655083 commit 10ccb20

File tree

8 files changed

+141
-229
lines changed

8 files changed

+141
-229
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
using System.Threading.Tasks;
1212
using Microsoft.Data.Common;
1313
using Microsoft.Data.ProviderBase;
14-
using Microsoft.Data.SqlClient.ManagedSni;
1514

1615
namespace Microsoft.Data.SqlClient
1716
{
@@ -57,10 +56,6 @@ protected TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCo
5756
// General methods //
5857
/////////////////////
5958

60-
internal abstract uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename);
61-
62-
internal abstract uint CheckConnection();
63-
6459
internal int DecrementPendingCallbacks(bool release)
6560
{
6661
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
@@ -215,8 +210,10 @@ private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
215210
return SniPacketGetData(packet, _inBuff, ref dataSize);
216211
}
217212

218-
private void SetBufferSecureStrings()
213+
private bool TrySetBufferSecureStrings()
219214
{
215+
bool mustClearBuffer = false;
216+
220217
if (_securePasswords != null)
221218
{
222219
for (int i = 0; i < _securePasswords.Length; i++)
@@ -240,6 +237,8 @@ private void SetBufferSecureStrings()
240237
}
241238
TdsParserStaticMethods.ObfuscatePassword(data);
242239
data.CopyTo(_outBuff, _securePasswordOffsetsInBuffer[i]);
240+
241+
mustClearBuffer = true;
243242
}
244243
finally
245244
{
@@ -248,6 +247,8 @@ private void SetBufferSecureStrings()
248247
}
249248
}
250249
}
250+
251+
return mustClearBuffer;
251252
}
252253

253254
public void ReadAsyncCallback(PacketHandle packet, uint error) =>
@@ -561,13 +562,7 @@ private Task SNIWritePacket(PacketHandle packet, out uint sniError, bool canAccu
561562
}
562563

563564
// Async operation completion may be delayed (success pending).
564-
try
565-
{
566-
}
567-
finally
568-
{
569-
sniError = WritePacket(packet, sync);
570-
}
565+
sniError = WritePacket(packet, sync);
571566

572567
if (sniError == TdsEnums.SNI_SUCCESS_IO_PENDING)
573568
{
@@ -730,17 +725,17 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
730725
}
731726
}
732727

733-
internal abstract PacketHandle CreateAndSetAttentionPacket();
734-
735-
internal abstract void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed);
736-
737728
private Task WriteSni(bool canAccumulate)
738729
{
739730
// Prepare packet, and write to packet.
740731
PacketHandle packet = GetResetWritePacket(_outBytesUsed);
732+
bool mustClearBuffer = TrySetBufferSecureStrings();
741733

742-
SetBufferSecureStrings();
743734
SetPacketData(packet, _outBuff, _outBytesUsed);
735+
if (mustClearBuffer)
736+
{
737+
_outBuff.AsSpan(0, _outBytesUsed).Clear();
738+
}
744739

745740
Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
746741
Task task = SNIWritePacket(packet, out _, canAccumulate, callerHasConnectionLock: true);

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
using System.Net;
99
using System.Runtime.InteropServices;
1010
using System.Security.Authentication;
11-
using System.Text;
1211
using System.Threading.Tasks;
1312
using Interop.Windows.Sni;
1413
using Microsoft.Data.Common;
@@ -308,10 +307,9 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint
308307

309308
internal override PacketHandle CreateAndSetAttentionPacket()
310309
{
311-
SNIHandle handle = Handle;
312-
SNIPacket attnPacket = new SNIPacket(handle);
310+
SNIPacket attnPacket = new SNIPacket(Handle);
313311
_sniAsyncAttnPacket = attnPacket;
314-
SetPacketData(PacketHandle.FromNativePacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN);
312+
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN);
315313
return PacketHandle.FromNativePacket(attnPacket);
316314
}
317315

@@ -399,28 +397,20 @@ internal override uint PostReadAsyncForMars(TdsParserStateObject physicalStateOb
399397
PacketHandle temp = default;
400398
uint error = TdsEnums.SNI_SUCCESS;
401399

402-
#if NETFRAMEWORK
403-
RuntimeHelpers.PrepareConstrainedRegions();
404-
#endif
405-
try
406-
{ }
407-
finally
408-
{
409-
IncrementPendingCallbacks();
410-
SessionHandle handle = SessionHandle;
411-
// we do not need to consider partial packets when making this read because we
412-
// expect this read to pend. a partial packet should not exist at setup of the
413-
// parser
414-
Debug.Assert(physicalStateObject.PartialPacket == null);
415-
temp = ReadAsync(handle, out error);
400+
IncrementPendingCallbacks();
401+
SessionHandle handle = SessionHandle;
402+
// we do not need to consider partial packets when making this read because we
403+
// expect this read to pend. a partial packet should not exist at setup of the
404+
// parser
405+
Debug.Assert(physicalStateObject.PartialPacket == null);
406+
temp = ReadAsync(handle, out error);
416407

417-
Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
408+
Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
418409

419-
if (temp.NativePointer != IntPtr.Zero)
420-
{
421-
// Be sure to release packet, otherwise it will be leaked by native.
422-
ReleasePacket(temp);
423-
}
410+
if (temp.NativePointer != IntPtr.Zero)
411+
{
412+
// Be sure to release packet, otherwise it will be leaked by native.
413+
ReleasePacket(temp);
424414
}
425415

426416
Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease");

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -942,19 +942,8 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ
942942
info |= TdsEnums.SNI_SSL_IGNORE_CHANNEL_BINDINGS;
943943
}
944944

945-
// Add SSL (Encryption) SNI provider.
946-
AuthProviderInfo authInfo = new AuthProviderInfo();
947-
authInfo.flags = info;
948-
authInfo.tlsFirst = encrypt == SqlConnectionEncryptOption.Strict;
949-
authInfo.certId = null;
950-
authInfo.certHash = false;
951-
authInfo.clientCertificateCallbackContext = IntPtr.Zero;
952-
authInfo.clientCertificateCallback = null;
953-
authInfo.serverCertFileName = string.IsNullOrEmpty(serverCertificateFilename) ? null : serverCertificateFilename;
954-
955945
Debug.Assert((_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0, "Client certificate authentication support has been removed");
956-
957-
error = SniNativeWrapper.SniAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);
946+
error = _physicalStateObj.EnableSsl(ref info, encrypt == SqlConnectionEncryptOption.Strict, serverCertificateFilename);
958947

959948
if (error != TdsEnums.SNI_SUCCESS)
960949
{

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,18 @@
44

55
using System;
66
using System.Buffers.Binary;
7-
using System.Collections.Generic;
87
using System.Diagnostics;
9-
using System.Runtime.CompilerServices;
10-
using System.Runtime.ConstrainedExecution;
118
using System.Runtime.InteropServices;
129
using System.Security;
1310
using System.Threading;
1411
using System.Threading.Tasks;
15-
using Interop.Windows.Sni;
1612
using Microsoft.Data.Common;
1713
using Microsoft.Data.ProviderBase;
1814

1915
namespace Microsoft.Data.SqlClient
2016
{
2117
internal partial class TdsParserStateObject
2218
{
23-
protected SNIHandle _sessionHandle = null; // the SNI handle we're to work on
24-
25-
// SNI variables // multiple resultsets in one batch.
26-
protected SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS
27-
internal SNIPacket _sniAsyncAttnPacket = null; // Packet to use to send Attn
28-
2919
// Used for blanking out password in trace.
3020
internal int _tracePasswordOffset = 0;
3121
internal int _tracePasswordLength = 0;
@@ -68,23 +58,10 @@ protected TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCo
6858
_lastSuccessfulIOTimer = parser._physicalStateObj._lastSuccessfulIOTimer;
6959
}
7060

71-
////////////////
72-
// Properties //
73-
////////////////
74-
internal SNIHandle Handle
75-
{
76-
get
77-
{
78-
return _sessionHandle;
79-
}
80-
}
81-
8261
/////////////////////
8362
// General methods //
8463
/////////////////////
8564

86-
internal uint CheckConnection() => SniNativeWrapper.SniCheckConnection(Handle);
87-
8865
internal int DecrementPendingCallbacks(bool release)
8966
{
9067
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
@@ -94,7 +71,7 @@ internal int DecrementPendingCallbacks(bool release)
9471

9572
// NOTE: TdsParserSessionPool may call DecrementPendingCallbacks on a TdsParserStateObject which is already disposed
9673
// This is not dangerous (since the stateObj is no longer in use), but we need to add a workaround in the assert for it
97-
Debug.Assert((remaining == -1 && _sessionHandle == null) || (0 <= remaining && remaining < 3), $"_pendingCallbacks values is invalid after decrementing: {remaining}");
74+
Debug.Assert((remaining == -1 && SessionHandle.IsNull) || (0 <= remaining && remaining < 3), $"_pendingCallbacks values is invalid after decrementing: {remaining}");
9875
return remaining;
9976
}
10077

@@ -121,11 +98,7 @@ internal bool ValidateSNIConnection()
12198
try
12299
{
123100
Interlocked.Increment(ref _readingCount);
124-
SNIHandle handle = Handle;
125-
if (handle != null)
126-
{
127-
error = SniNativeWrapper.SniCheckConnection(handle);
128-
}
101+
error = CheckConnection();
129102
}
130103
finally
131104
{
@@ -243,6 +216,47 @@ private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
243216
return SniPacketGetData(packet, _inBuff, ref dataSize);
244217
}
245218

219+
private bool TrySetBufferSecureStrings()
220+
{
221+
bool mustClearBuffer = false;
222+
223+
if (_securePasswords != null)
224+
{
225+
for (int i = 0; i < _securePasswords.Length; i++)
226+
{
227+
if (_securePasswords[i] != null)
228+
{
229+
IntPtr str = IntPtr.Zero;
230+
try
231+
{
232+
str = Marshal.SecureStringToBSTR(_securePasswords[i]);
233+
byte[] data = new byte[_securePasswords[i].Length * 2];
234+
Marshal.Copy(str, data, 0, _securePasswords[i].Length * 2);
235+
if (!BitConverter.IsLittleEndian)
236+
{
237+
Span<byte> span = data.AsSpan();
238+
for (int ii = 0; ii < _securePasswords[i].Length * 2; ii += 2)
239+
{
240+
short value = BinaryPrimitives.ReadInt16LittleEndian(span.Slice(ii));
241+
BinaryPrimitives.WriteInt16BigEndian(span.Slice(ii), value);
242+
}
243+
}
244+
TdsParserStaticMethods.ObfuscatePassword(data);
245+
data.CopyTo(_outBuff, _securePasswordOffsetsInBuffer[i]);
246+
247+
mustClearBuffer = true;
248+
}
249+
finally
250+
{
251+
Marshal.ZeroFreeBSTR(str);
252+
}
253+
}
254+
}
255+
}
256+
257+
return mustClearBuffer;
258+
}
259+
246260
public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error)
247261
{
248262
// Key never used.
@@ -717,20 +731,17 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
717731
}
718732
}
719733

720-
internal PacketHandle CreateAndSetAttentionPacket()
721-
{
722-
SNIPacket attnPacket = new SNIPacket(Handle);
723-
_sniAsyncAttnPacket = attnPacket;
724-
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);
725-
return PacketHandle.FromNativePacket(attnPacket);
726-
}
727-
728734
private Task WriteSni(bool canAccumulate)
729735
{
730736
// Prepare packet, and write to packet.
731737
PacketHandle packet = GetResetWritePacket(_outBytesUsed);
732-
SNIPacket nativePacket = packet.NativePacket;
733-
SniNativeWrapper.SniPacketSetData(nativePacket, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);
738+
bool mustClearBuffer = TrySetBufferSecureStrings();
739+
740+
SetPacketData(packet, _outBuff, _outBytesUsed);
741+
if (mustClearBuffer)
742+
{
743+
_outBuff.AsSpan(0, _outBytesUsed).Clear();
744+
}
734745

735746
Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
736747
Task task = SNIWritePacket(packet, out _, canAccumulate, callerHasConnectionLock: true);

0 commit comments

Comments
 (0)