Skip to content

Commit 80c56f0

Browse files
committed
Rewrite ByteArrayReader API to use ReadOnlySpan<byte>.
Signed-off-by: Bradley Grainger <[email protected]>
1 parent 7030baf commit 80c56f0

File tree

10 files changed

+80
-71
lines changed

10 files changed

+80
-71
lines changed

src/MySqlConnector/Core/ResultSet.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,15 @@ public async Task<ResultSet> ReadResultSetHeaderAsync(IOBehavior ioBehavior)
9090
}
9191
else
9292
{
93-
var reader = new ByteArrayReader(payload.ArraySegment);
94-
var columnCount = (int) reader.ReadLengthEncodedInteger();
95-
if (reader.BytesRemaining != 0)
96-
throw new MySqlException("Unexpected data at end of column_count packet; see https://github.com/mysql-net/MySqlConnector/issues/324");
93+
int ReadColumnCount(ArraySegment<byte> arraySegment)
94+
{
95+
var reader = new ByteArrayReader(arraySegment);
96+
var columnCount_ = (int) reader.ReadLengthEncodedInteger();
97+
if (reader.BytesRemaining != 0)
98+
throw new MySqlException("Unexpected data at end of column_count packet; see https://github.com/mysql-net/MySqlConnector/issues/324");
99+
return columnCount_;
100+
}
101+
var columnCount = ReadColumnCount(payload.ArraySegment);
97102

98103
// reserve adequate space to hold a copy of all column definitions (but note that this can be resized below if we guess too small)
99104
Array.Resize(ref m_columnDefinitionPayloads, columnCount * 96);
@@ -251,7 +256,7 @@ Row ScanRowAsyncRemainder(PayloadData payload, Row row_)
251256
{
252257
var length = reader.ReadLengthEncodedIntegerOrNull();
253258
m_dataLengths[column] = length == -1 ? 0 : length;
254-
m_dataOffsets[column] = length == -1 ? -1 : reader.Offset;
259+
m_dataOffsets[column] = length == -1 ? -1 : reader.Offset + payload.ArraySegment.Offset;
255260
reader.Offset += m_dataLengths[column];
256261
}
257262

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Buffers.Text;
23
using System.Data;
34
using System.Diagnostics;
45
using System.Globalization;
@@ -1020,16 +1021,16 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation
10201021
}
10211022

10221023
// first (and only) row
1023-
int? connectionId = default;
1024-
string serverVersion = null;
10251024
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
1026-
var reader = new ByteArrayReader(payload.ArraySegment);
1027-
var length = reader.ReadLengthEncodedIntegerOrNull();
1028-
if (length != -1)
1029-
connectionId = int.Parse(Encoding.UTF8.GetString(reader.ReadByteArraySegment(length)), CultureInfo.InvariantCulture);
1030-
length = reader.ReadLengthEncodedIntegerOrNull();
1031-
if (length != -1)
1032-
serverVersion = Encoding.UTF8.GetString(reader.ReadByteArraySegment(length));
1025+
void ReadRow(ArraySegment<byte> arraySegment, out int? connectionId_, out string serverVersion_)
1026+
{
1027+
var reader = new ByteArrayReader(arraySegment);
1028+
var length = reader.ReadLengthEncodedIntegerOrNull();
1029+
connectionId_ = (length != -1 && Utf8Parser.TryParse(reader.ReadByteString(length), out int id, out _)) ? id : default(int?);
1030+
length = reader.ReadLengthEncodedIntegerOrNull();
1031+
serverVersion_ = length != -1 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null;
1032+
}
1033+
ReadRow(payload.ArraySegment, out var connectionId, out var serverVersion);
10331034

10341035
// OK/EOF payload
10351036
payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);

src/MySqlConnector/MySqlConnector.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<AssemblyName>MySqlConnector</AssemblyName>
1010
<PackageId>MySqlConnector</PackageId>
1111
<PackageTags>mysql;mysqlconnector;async;ado.net;database;netcore</PackageTags>
12+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
1213
</PropertyGroup>
1314

1415
<ItemGroup Condition=" '$(TargetFramework)' == 'net45' OR '$(TargetFramework)' == 'net461' ">

src/MySqlConnector/Protocol/Payloads/AuthenticationMethodSwitchRequestPayload.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Text;
22
using MySqlConnector.Protocol.Serialization;
3+
using MySqlConnector.Utilities;
34

45
namespace MySqlConnector.Protocol.Payloads
56
{
@@ -26,7 +27,7 @@ public static AuthenticationMethodSwitchRequestPayload Create(PayloadData payloa
2627
else
2728
{
2829
name = Encoding.UTF8.GetString(reader.ReadNullTerminatedByteString());
29-
data = reader.ReadByteArray(reader.BytesRemaining);
30+
data = reader.ReadByteString(reader.BytesRemaining).ToArray();
3031
}
3132
return new AuthenticationMethodSwitchRequestPayload(name, data);
3233
}

src/MySqlConnector/Protocol/Payloads/AuthenticationMoreDataPayload.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public static AuthenticationMoreDataPayload Create(PayloadData payload)
1212
{
1313
var reader = new ByteArrayReader(payload.ArraySegment);
1414
reader.ReadByte(Signature);
15-
return new AuthenticationMoreDataPayload(reader.ReadByteArray(reader.BytesRemaining));
15+
return new AuthenticationMoreDataPayload(reader.ReadByteString(reader.BytesRemaining).ToArray());
1616
}
1717

1818
private AuthenticationMoreDataPayload(byte[] data) => Data = data;

src/MySqlConnector/Protocol/Payloads/ErrorPayload.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ public static ErrorPayload Create(PayloadData payload)
2020
reader.ReadByte(Signature);
2121

2222
var errorCode = reader.ReadUInt16();
23-
var stateMarker = Encoding.ASCII.GetString(reader.ReadByteArraySegment(1));
23+
var stateMarker = Encoding.ASCII.GetString(reader.ReadByteString(1));
2424
string state, message;
2525
if (stateMarker == "#")
2626
{
27-
state = Encoding.ASCII.GetString(reader.ReadByteArraySegment(5));
28-
message = Encoding.UTF8.GetString(reader.ReadByteArraySegment(payload.ArraySegment.Count - 9));
27+
state = Encoding.ASCII.GetString(reader.ReadByteString(5));
28+
message = Encoding.UTF8.GetString(reader.ReadByteString(payload.ArraySegment.Count - 9));
2929
}
3030
else
3131
{
3232
state = "HY000";
33-
message = stateMarker + Encoding.UTF8.GetString(reader.ReadByteArraySegment(payload.ArraySegment.Count - 4));
33+
message = stateMarker + Encoding.UTF8.GetString(reader.ReadByteString(payload.ArraySegment.Count - 4));
3434
}
3535
return new ErrorPayload(errorCode, state, message);
3636
}

src/MySqlConnector/Protocol/Payloads/InitialHandshakePayload.cs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Text;
33
using MySqlConnector.Protocol.Serialization;
4+
using MySqlConnector.Utilities;
45

56
namespace MySqlConnector.Protocol.Payloads
67
{
@@ -19,7 +20,7 @@ public static InitialHandshakePayload Create(PayloadData payload)
1920
var serverVersion = reader.ReadNullTerminatedByteString();
2021
var connectionId = reader.ReadInt32();
2122
byte[] authPluginData = null;
22-
var authPluginData1 = reader.ReadByteArraySegment(8);
23+
var authPluginData1 = reader.ReadByteString(8);
2324
string authPluginName = null;
2425
reader.ReadByte(0);
2526
var protocolCapabilities = (ProtocolCapabilities) reader.ReadUInt16();
@@ -30,28 +31,24 @@ public static InitialHandshakePayload Create(PayloadData payload)
3031
var capabilityFlagsHigh = reader.ReadUInt16();
3132
protocolCapabilities |= (ProtocolCapabilities) (capabilityFlagsHigh << 16);
3233
var authPluginDataLength = reader.ReadByte();
33-
var unused = reader.ReadByteArraySegment(10);
34+
var unused = reader.ReadByteString(10);
3435
if ((protocolCapabilities & ProtocolCapabilities.SecureConnection) != 0)
3536
{
36-
var authPluginData2 = reader.ReadByteArraySegment(Math.Max(13, authPluginDataLength - 8));
37-
var concatenated = new byte[authPluginData1.Count + authPluginData2.Count];
38-
Buffer.BlockCopy(authPluginData1.Array, authPluginData1.Offset, concatenated, 0, authPluginData1.Count);
39-
Buffer.BlockCopy(authPluginData2.Array, authPluginData2.Offset, concatenated, authPluginData1.Count, authPluginData2.Count);
40-
authPluginData = concatenated;
37+
var authPluginData2 = reader.ReadByteString(Math.Max(13, authPluginDataLength - 8));
38+
authPluginData = new byte[authPluginData1.Length + authPluginData2.Length];
39+
authPluginData1.CopyTo(authPluginData);
40+
authPluginData2.CopyTo(new Span<byte>(authPluginData).Slice(authPluginData1.Length));
4141
}
4242
if ((protocolCapabilities & ProtocolCapabilities.PluginAuth) != 0)
4343
authPluginName = Encoding.UTF8.GetString(reader.ReadNullOrEofTerminatedByteString());
4444
}
4545
if (authPluginData == null)
46-
{
47-
authPluginData = new byte[authPluginData1.Count];
48-
Buffer.BlockCopy(authPluginData1.Array, authPluginData1.Offset, authPluginData, 0, authPluginData1.Count);
49-
}
46+
authPluginData = authPluginData1.ToArray();
5047

5148
if (reader.BytesRemaining != 0)
5249
throw new FormatException("Extra bytes at end of payload.");
5350

54-
return new InitialHandshakePayload(protocolCapabilities, serverVersion, connectionId, authPluginData, authPluginName);
51+
return new InitialHandshakePayload(protocolCapabilities, serverVersion.ToArray(), connectionId, authPluginData, authPluginName);
5552
}
5653

5754
private InitialHandshakePayload(ProtocolCapabilities protocolCapabilities, byte[] serverVersion, int connectionId, byte[] authPluginData, string authPluginName)

src/MySqlConnector/Protocol/Payloads/LocalInfilePayload.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public static LocalInfilePayload Create(PayloadData payload)
1515
{
1616
var reader = new ByteArrayReader(payload.ArraySegment);
1717
reader.ReadByte(Signature);
18-
var fileName = Encoding.UTF8.GetString(reader.ReadByteArraySegment(reader.BytesRemaining));
18+
var fileName = Encoding.UTF8.GetString(reader.ReadByteString(reader.BytesRemaining));
1919
return new LocalInfilePayload(fileName);
2020
}
2121

src/MySqlConnector/Protocol/Serialization/ByteArrayReader.cs

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
using System;
2+
using System.Buffers.Binary;
23
using MySqlConnector.Utilities;
34

45
namespace MySqlConnector.Protocol.Serialization
56
{
6-
internal struct ByteArrayReader
7+
internal ref struct ByteArrayReader
78
{
8-
public ByteArrayReader(byte[] buffer, int offset, int length)
9+
public ByteArrayReader(ReadOnlySpan<byte> buffer)
910
{
10-
m_buffer = buffer ?? throw new ArgumentNullException(nameof(buffer));
11-
m_offset = offset >= 0 ? offset : throw new ArgumentOutOfRangeException(nameof(offset));
12-
m_maxOffset = offset + length <= m_buffer.Length ? offset + length : throw new ArgumentOutOfRangeException(nameof(length));
11+
m_buffer = buffer;
12+
m_offset = 0;
13+
m_maxOffset = buffer.Length;
1314
}
1415

1516
public ByteArrayReader(ArraySegment<byte> arraySegment)
16-
: this(arraySegment.Array, arraySegment.Offset, arraySegment.Count)
17+
: this(arraySegment.AsSpan())
1718
{
1819
}
1920

@@ -38,31 +39,31 @@ public void ReadByte(byte value)
3839
public short ReadInt16()
3940
{
4041
VerifyRead(2);
41-
var result = BitConverter.ToInt16(m_buffer, m_offset);
42+
var result = BinaryPrimitives.ReadInt16LittleEndian(m_buffer.Slice(m_offset));
4243
m_offset += 2;
4344
return result;
4445
}
4546

4647
public ushort ReadUInt16()
4748
{
4849
VerifyRead(2);
49-
var result = BitConverter.ToUInt16(m_buffer, m_offset);
50+
var result = BinaryPrimitives.ReadUInt16LittleEndian(m_buffer.Slice(m_offset));
5051
m_offset += 2;
5152
return result;
5253
}
5354

5455
public int ReadInt32()
5556
{
5657
VerifyRead(4);
57-
var result = BitConverter.ToInt32(m_buffer, m_offset);
58+
var result = BinaryPrimitives.ReadInt32LittleEndian(m_buffer.Slice(m_offset));
5859
m_offset += 4;
5960
return result;
6061
}
6162

6263
public uint ReadUInt32()
6364
{
6465
VerifyRead(4);
65-
var result = BitConverter.ToUInt32(m_buffer, m_offset);
66+
var result = BinaryPrimitives.ReadUInt32LittleEndian(m_buffer.Slice(m_offset));
6667
m_offset += 4;
6768
return result;
6869
}
@@ -91,46 +92,34 @@ public ulong ReadFixedLengthUInt64(int length)
9192
return result;
9293
}
9394

94-
// TODO: Span<byte>
95-
public byte[] ReadNullTerminatedByteString()
95+
public ReadOnlySpan<byte> ReadNullTerminatedByteString()
9696
{
9797
int index = m_offset;
9898
while (index < m_maxOffset && m_buffer[index] != 0)
9999
index++;
100100
if (index == m_maxOffset)
101101
throw new FormatException("Read past end of buffer looking for NUL.");
102-
byte[] substring = new byte[index - m_offset];
103-
Buffer.BlockCopy(m_buffer, m_offset, substring, 0, substring.Length);
102+
var substring = m_buffer.Slice(m_offset, index - m_offset);
104103
m_offset = index + 1;
105104
return substring;
106105
}
107106

108-
public byte[] ReadNullOrEofTerminatedByteString()
107+
public ReadOnlySpan<byte> ReadNullOrEofTerminatedByteString()
109108
{
110109
int index = m_offset;
111110
while (index < m_maxOffset && m_buffer[index] != 0)
112111
index++;
113-
byte[] substring = new byte[index - m_offset];
114-
Buffer.BlockCopy(m_buffer, m_offset, substring, 0, substring.Length);
112+
var substring = m_buffer.Slice(m_offset, index - m_offset);
115113
if (index < m_maxOffset && m_buffer[index] == 0)
116114
index++;
117115
m_offset = index;
118116
return substring;
119117
}
120118

121-
public byte[] ReadByteArray(int length)
119+
public ReadOnlySpan<byte> ReadByteString(int length)
122120
{
123121
VerifyRead(length);
124-
var result = new byte[length];
125-
Buffer.BlockCopy(m_buffer, m_offset, result, 0, result.Length);
126-
m_offset += length;
127-
return result;
128-
}
129-
130-
public ArraySegment<byte> ReadByteArraySegment(int length)
131-
{
132-
VerifyRead(length);
133-
var result = new ArraySegment<byte>(m_buffer, m_offset, length);
122+
var result = m_buffer.Slice(m_offset, length);
134123
m_offset += length;
135124
return result;
136125
}
@@ -166,10 +155,10 @@ public int ReadLengthEncodedIntegerOrNull()
166155
return checked((int) ReadLengthEncodedInteger());
167156
}
168157

169-
public ArraySegment<byte> ReadLengthEncodedByteString()
158+
public ReadOnlySpan<byte> ReadLengthEncodedByteString()
170159
{
171160
var length = checked((int) ReadLengthEncodedInteger());
172-
var result = new ArraySegment<byte>(m_buffer, m_offset, length);
161+
var result = m_buffer.Slice(m_offset, length);
173162
m_offset += length;
174163
return result;
175164
}
@@ -182,7 +171,7 @@ private void VerifyRead(int length)
182171
throw new InvalidOperationException("Read past end of buffer.");
183172
}
184173

185-
readonly byte[] m_buffer;
174+
readonly ReadOnlySpan<byte> m_buffer;
186175
readonly int m_maxOffset;
187176
int m_offset;
188177
}

src/MySqlConnector/Utilities/Utility.cs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ public static string FormatInvariant(this string format, params object[] args) =
3434
public static string GetString(this Encoding encoding, ArraySegment<byte> arraySegment) =>
3535
encoding.GetString(arraySegment.Array, arraySegment.Offset, arraySegment.Count);
3636

37+
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> span)
38+
{
39+
if (span.Length == 0)
40+
return "";
41+
#if NET45
42+
return encoding.GetString(span.ToArray());
43+
#else
44+
unsafe
45+
{
46+
fixed (byte* ptr = span)
47+
return encoding.GetString(ptr, span.Length);
48+
}
49+
#endif
50+
}
51+
3752
/// <summary>
3853
/// Loads a RSA public key from a PEM string. Taken from <a href="https://stackoverflow.com/a/32243171/23633">Stack Overflow</a>.
3954
/// </summary>
@@ -172,21 +187,21 @@ public static byte[] ArraySlice(byte[] input, int offset, int length)
172187
}
173188

174189
/// <summary>
175-
/// Finds the next index of <paramref name="pattern"/> in <paramref name="array"/>, starting at index <paramref name="offset"/>.
190+
/// Finds the next index of <paramref name="pattern"/> in <paramref name="data"/>, starting at index <paramref name="offset"/>.
176191
/// </summary>
177-
/// <param name="array">The array to search.</param>
192+
/// <param name="data">The array to search.</param>
178193
/// <param name="offset">The offset at which to start searching.</param>
179-
/// <param name="pattern">The pattern to find in <paramref name="array"/>.</param>
180-
/// <returns>The offset of <paramref name="pattern"/> within <paramref name="array"/>, or <c>-1</c> if <paramref name="pattern"/> was not found.</returns>
181-
public static int FindNextIndex(byte[] array, int offset, byte[] pattern)
194+
/// <param name="pattern">The pattern to find in <paramref name="data"/>.</param>
195+
/// <returns>The offset of <paramref name="pattern"/> within <paramref name="data"/>, or <c>-1</c> if <paramref name="pattern"/> was not found.</returns>
196+
public static int FindNextIndex(ReadOnlySpan<byte> data, int offset, ReadOnlySpan<byte> pattern)
182197
{
183-
var limit = array.Length - pattern.Length;
198+
var limit = data.Length - pattern.Length;
184199
for (var start = offset; start <= limit; start++)
185200
{
186201
var i = 0;
187202
for (; i < pattern.Length; i++)
188203
{
189-
if (array[start + i] != pattern[i])
204+
if (data[start + i] != pattern[i])
190205
break;
191206
}
192207
if (i == pattern.Length)

0 commit comments

Comments
 (0)