Skip to content

Commit 2b3765f

Browse files
committed
Overhaul implementation and add tests.
Signed-off-by: Bradley Grainger <[email protected]>
1 parent ffe19aa commit 2b3765f

File tree

4 files changed

+231
-24
lines changed

4 files changed

+231
-24
lines changed

src/MySqlConnector/MySqlCommand.cs

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using System.Buffers;
21
using System.Diagnostics.CodeAnalysis;
2+
using System.Runtime.InteropServices;
33
using Microsoft.Extensions.Logging;
44
using MySqlConnector.Core;
55
using MySqlConnector.Protocol;
@@ -368,34 +368,50 @@ internal async ValueTask<MySqlDataReader> ExecuteReaderNoResetTimeoutAsync(Comma
368368
{
369369
for (var i = 0; i < Parameters.Count; i++)
370370
{
371-
if (Parameters[i].Value is Stream stream)
371+
if (Parameters[i].Value is Stream stream and not MemoryStream)
372372
{
373-
var writer = new ByteBufferWriter();
374-
writer.Write((byte) CommandKind.StatementSendLongData);
375-
writer.Write(statements.Statements[0].StatementId);
376-
writer.Write((ushort) i);
377-
378-
var buffer = ArrayPool<byte>.Shared.Rent(ProtocolUtility.MaxPacketSize);
379-
var dataLength = ProtocolUtility.MaxPacketSize - 7;
380-
381-
var bytesRead = stream.Read(buffer, 0, dataLength);
382-
writer.Write(buffer.AsSpan(0, bytesRead));
383-
384-
await Connection!.Session.SendAsync(writer.ToPayloadData(), ioBehavior, cancellationToken).ConfigureAwait(false);
385-
386-
if (bytesRead == dataLength)
373+
// send almost-full packets, but don't send exactly ProtocolUtility.MaxPacketSize bytes in one payload (as that's ambiguous to whether another packet follows).
374+
const int maxDataSize = 16_000_000;
375+
int totalBytesRead;
376+
while (true)
387377
{
388-
dataLength = ProtocolUtility.MaxPacketSize;
389-
378+
// write seven-byte COM_STMT_SEND_LONG_DATA header: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html
379+
var writer = new ByteBufferWriter(maxDataSize);
380+
writer.Write((byte) CommandKind.StatementSendLongData);
381+
writer.Write(statements.Statements[0].StatementId);
382+
writer.Write((ushort) i);
383+
384+
// keep reading from the stream until we've filled the buffer to send
385+
#if NET7_0_OR_GREATER
386+
if (ioBehavior == IOBehavior.Synchronous)
387+
totalBytesRead = stream.ReadAtLeast(writer.GetSpan(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false);
388+
else
389+
totalBytesRead = await stream.ReadAtLeastAsync(writer.GetMemory(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
390+
writer.Advance(totalBytesRead);
391+
#else
392+
totalBytesRead = 0;
393+
int bytesRead;
390394
do
391395
{
392-
bytesRead = stream.Read(buffer, 0, dataLength);
396+
var sizeToRead = maxDataSize - totalBytesRead;
397+
ReadOnlyMemory<byte> bufferMemory = writer.GetMemory(sizeToRead);
398+
if (!MemoryMarshal.TryGetArray(bufferMemory, out var arraySegment))
399+
throw new InvalidOperationException("Failed to get array segment from buffer memory.");
400+
if (ioBehavior == IOBehavior.Synchronous)
401+
bytesRead = stream.Read(arraySegment.Array!, arraySegment.Offset, sizeToRead);
402+
else
403+
bytesRead = await stream.ReadAsync(arraySegment.Array!, arraySegment.Offset, sizeToRead, cancellationToken).ConfigureAwait(false);
404+
totalBytesRead += bytesRead;
405+
writer.Advance(bytesRead);
406+
} while (bytesRead > 0);
407+
#endif
408+
409+
if (totalBytesRead == 0)
410+
break;
393411

394-
writer = new ByteBufferWriter();
395-
writer.Write(buffer.AsSpan(0, bytesRead));
396-
await Connection!.Session.SendReplyAsync(writer.ToPayloadData(), ioBehavior, cancellationToken).ConfigureAwait(false);
397-
}
398-
while (bytesRead == dataLength);
412+
// send StatementSendLongData; MySQL Server will keep appending the sent data to the parameter value
413+
using var payload = writer.ToPayloadData();
414+
await Connection!.Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
399415
}
400416
}
401417
}

src/MySqlConnector/MySqlParameter.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,10 @@ internal void AppendSqlString(ByteBufferWriter writer, StatementPreparerOptions
331331
Debug.Assert(index == length, "index == length");
332332
writer.Advance(index);
333333
}
334+
else if (Value is Stream)
335+
{
336+
throw new NotSupportedException($"Parameter type {Value.GetType().Name} can only be used after calling MySqlCommand.Prepare.");
337+
}
334338
else if (Value is bool boolValue)
335339
{
336340
writer.Write(boolValue ? "true"u8 : "false"u8);
@@ -721,6 +725,13 @@ private void AppendBinary(ByteBufferWriter writer, object value, StatementPrepar
721725
writer.WriteLengthEncodedInteger(unchecked((ulong) geometry.ValueSpan.Length));
722726
writer.Write(geometry.ValueSpan);
723727
}
728+
else if (value is MemoryStream memoryStream)
729+
{
730+
if (!memoryStream.TryGetBuffer(out var streamBuffer))
731+
streamBuffer = new ArraySegment<byte>(memoryStream.ToArray());
732+
writer.WriteLengthEncodedInteger(unchecked((ulong) streamBuffer.Count));
733+
writer.Write(streamBuffer);
734+
}
724735
else if (value is Stream)
725736
{
726737
// do nothing; this will be sent via CommandKind.StatementSendLongData

tests/IntegrationTests/ChunkStream.cs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
namespace IntegrationTests;
2+
3+
internal sealed class ChunkStream : Stream
4+
{
5+
public ChunkStream(int dataLength, int chunkLength)
6+
{
7+
m_dataLength = dataLength;
8+
m_chunkLength = chunkLength;
9+
m_position = 0;
10+
}
11+
12+
public override bool CanRead => true;
13+
public override bool CanSeek => false;
14+
public override bool CanWrite => false;
15+
public override long Length => m_dataLength;
16+
public override long Position
17+
{
18+
get => m_position;
19+
set => throw new NotSupportedException();
20+
}
21+
22+
public override int Read(byte[] buffer, int offset, int count)
23+
{
24+
if (buffer is null)
25+
throw new ArgumentNullException(nameof(buffer));
26+
if (offset < 0 || offset > buffer.Length)
27+
throw new ArgumentOutOfRangeException(nameof(offset));
28+
if (count < 0 || offset + count > buffer.Length)
29+
throw new ArgumentOutOfRangeException(nameof(count));
30+
31+
return Read(buffer.AsSpan(offset, count));
32+
}
33+
34+
public
35+
#if NETSTANDARD2_1 || NETCOREAPP2_1_OR_GREATER
36+
override
37+
#endif
38+
int Read(Span<byte> buffer)
39+
{
40+
if (m_position >= m_dataLength)
41+
return 0;
42+
43+
// Read at most chunkLength bytes
44+
var bytesToRead = Math.Min(buffer.Length, Math.Min(m_chunkLength, m_dataLength - m_position));
45+
46+
// Fill with dummy data (repeating pattern based on position)
47+
for (var i = 0; i < bytesToRead; i++)
48+
{
49+
buffer[i] = (byte) ((m_position + i) % 256);
50+
}
51+
52+
m_position += bytesToRead;
53+
return bytesToRead;
54+
}
55+
56+
public override int ReadByte()
57+
{
58+
Span<byte> buffer = stackalloc byte[1];
59+
var bytesRead = Read(buffer);
60+
return bytesRead == 0 ? -1 : buffer[0];
61+
}
62+
63+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
64+
{
65+
if (buffer is null)
66+
throw new ArgumentNullException(nameof(buffer));
67+
if (offset < 0 || offset > buffer.Length)
68+
throw new ArgumentOutOfRangeException(nameof(offset));
69+
if (count < 0 || offset + count > buffer.Length)
70+
throw new ArgumentOutOfRangeException(nameof(count));
71+
72+
if (cancellationToken.IsCancellationRequested)
73+
return Task.FromCanceled<int>(cancellationToken);
74+
75+
try
76+
{
77+
return Task.FromResult(Read(buffer.AsSpan(offset, count)));
78+
}
79+
catch (Exception ex)
80+
{
81+
return Task.FromException<int>(ex);
82+
}
83+
}
84+
85+
public
86+
#if NETSTANDARD2_1 || NETCOREAPP2_1_OR_GREATER
87+
override
88+
#endif
89+
ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
90+
{
91+
if (cancellationToken.IsCancellationRequested)
92+
return new(Task.FromCanceled<int>(cancellationToken));
93+
94+
try
95+
{
96+
return new(Read(buffer.Span));
97+
}
98+
catch (Exception ex)
99+
{
100+
return new(Task.FromException<int>(ex));
101+
}
102+
}
103+
104+
public override void Write(byte[] buffer, int offset, int count) =>
105+
throw new NotSupportedException();
106+
107+
public override void WriteByte(byte value) =>
108+
throw new NotSupportedException();
109+
110+
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
111+
throw new NotSupportedException();
112+
113+
public override void SetLength(long value) =>
114+
throw new NotSupportedException();
115+
116+
public override long Seek(long offset, SeekOrigin origin) =>
117+
throw new NotSupportedException();
118+
119+
public override void Flush() =>
120+
throw new NotSupportedException();
121+
122+
public override Task FlushAsync(CancellationToken cancellationToken) =>
123+
throw new NotSupportedException();
124+
125+
private readonly int m_dataLength;
126+
private readonly int m_chunkLength;
127+
private int m_position;
128+
}

tests/IntegrationTests/InsertTests.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,58 @@ public void InsertMySqlDecimalAsDecimal(bool prepare)
477477
}
478478
#endif
479479

480+
[Theory]
481+
[InlineData(1_000_000, 1024, true)]
482+
[InlineData(1_000_000, 1024, false)]
483+
[InlineData(1_000_000, int.MaxValue, true)]
484+
[InlineData(1_000_000, int.MaxValue, false)]
485+
[InlineData(0xff_fff8, 299593, true)]
486+
[InlineData(0xff_fff8, 299593, false)]
487+
[InlineData(0xff_fff8, 300000, true)]
488+
[InlineData(0xff_fff8, 300000, false)]
489+
[InlineData(0xff_fff8, int.MaxValue, true)]
490+
[InlineData(0xff_fff8, int.MaxValue, false)]
491+
[InlineData(0xff_fff9, int.MaxValue, true)]
492+
[InlineData(0xff_fff9, int.MaxValue, false)]
493+
[InlineData(0x1ff_fff0, 299593, true)]
494+
[InlineData(0x1ff_fff0, 299593, false)]
495+
[InlineData(0x1ff_fff0, 300000, true)]
496+
[InlineData(0x1ff_fff0, 300000, false)]
497+
[InlineData(15_999_999, int.MaxValue, true)]
498+
[InlineData(15_999_999, int.MaxValue, false)]
499+
[InlineData(16_000_000, int.MaxValue, true)]
500+
[InlineData(16_000_000, int.MaxValue, false)]
501+
[InlineData(16_000_001, int.MaxValue, true)]
502+
[InlineData(16_000_001, int.MaxValue, false)]
503+
[InlineData(31_999_999, 999_999, true)]
504+
[InlineData(31_999_999, 1_000_000, false)]
505+
[InlineData(32_000_000, 1_000_001, true)]
506+
[InlineData(32_000_000, 1_000_002, false)]
507+
[InlineData(32_000_001, 1_000_003, true)]
508+
[InlineData(32_000_001, 1_000_004, false)]
509+
public async Task SendLongData(int dataLength, int chunkLength, bool isAsync)
510+
{
511+
using MySqlConnection connection = new MySqlConnection(AppConfig.ConnectionString);
512+
connection.Open();
513+
connection.Execute("""
514+
drop table if exists insert_mysql_long_data;
515+
create table insert_mysql_long_data(rowid integer not null primary key auto_increment, value longblob);
516+
""");
517+
518+
using var chunkStream = new ChunkStream(dataLength, chunkLength);
519+
520+
using var writeCommand = new MySqlCommand("insert into insert_mysql_long_data(value) values(@value);", connection);
521+
writeCommand.Parameters.AddWithValue("@value", chunkStream);
522+
writeCommand.Prepare();
523+
if (isAsync)
524+
await writeCommand.ExecuteNonQueryAsync().ConfigureAwait(true);
525+
else
526+
writeCommand.ExecuteNonQuery();
527+
528+
using var readLengthCommand = new MySqlCommand("select length(value) from insert_mysql_long_data order by rowid;", connection);
529+
Assert.Equal(chunkStream.Length, readLengthCommand.ExecuteScalar());
530+
}
531+
480532
[Theory]
481533
[InlineData(false)]
482534
[InlineData(true)]

0 commit comments

Comments
 (0)