diff --git a/docs/content/troubleshooting/parameter-types.md b/docs/content/troubleshooting/parameter-types.md index 82cd238d0..ae5123b0b 100644 --- a/docs/content/troubleshooting/parameter-types.md +++ b/docs/content/troubleshooting/parameter-types.md @@ -38,7 +38,8 @@ In some cases, this may be as simple as calling `.ToString()` or `.ToString(Cult * .NET primitives: `bool`, `byte`, `char`, `double`, `float`, `int`, `long`, `sbyte`, `short`, `uint`, `ulong`, `ushort` * Common types: `BigInteger`, `DateOnly`, `DateTime`, `DateTimeOffset`, `decimal`, `enum`, `Guid`, `string`, `TimeOnly`, `TimeSpan` -* BLOB types: `ArraySegment`, `byte[]`, `Memory`, `ReadOnlyMemory` +* BLOB types: `ArraySegment`, `byte[]`, `Memory`, `ReadOnlyMemory`. `MemoryStream` + * NOTE: `System.IO.Stream` and derived types (other than `MemoryStream`) are supported only when `MySqlCommand.Prepare` is called first. The data in the `Stream` will be streamed to the database server as binary data. * Vector types: `float[]`, `Memory`, `ReadOnlyMemory` * String types: `Memory`, `ReadOnlyMemory`, `StringBuilder` * Custom MySQL types: `MySqlDateTime`, `MySqlDecimal`, `MySqlGeometry` diff --git a/src/MySqlConnector/Core/CommandExecutor.cs b/src/MySqlConnector/Core/CommandExecutor.cs index a2581334a..732fec43c 100644 --- a/src/MySqlConnector/Core/CommandExecutor.cs +++ b/src/MySqlConnector/Core/CommandExecutor.cs @@ -39,6 +39,8 @@ public static async ValueTask ExecuteReaderAsync(CommandListPos } } + await payloadCreator.SendCommandPrologueAsync(connection, commandListPosition, ioBehavior, cancellationToken).ConfigureAwait(false); + var writer = new ByteBufferWriter(); //// cachedProcedures will be non-null if there is a stored procedure, which is also the only time it will be read if (!payloadCreator.WriteQueryCommand(ref commandListPosition, cachedProcedures!, writer, false)) diff --git a/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs b/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs index 20e0717f6..c87f59080 100644 --- a/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs +++ b/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs @@ -8,6 +8,9 @@ internal sealed class ConcatenatedCommandPayloadCreator : ICommandPayloadCreator { public static ICommandPayloadCreator Instance { get; } = new ConcatenatedCommandPayloadCreator(); + public ValueTask SendCommandPrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken) => + default; + public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDictionary cachedProcedures, ByteBufferWriter writer, bool appendSemicolon) { if (commandListPosition.CommandIndex == commandListPosition.CommandCount) diff --git a/src/MySqlConnector/Core/ICommandPayloadCreator.cs b/src/MySqlConnector/Core/ICommandPayloadCreator.cs index c6e32d04a..14d6fb9c2 100644 --- a/src/MySqlConnector/Core/ICommandPayloadCreator.cs +++ b/src/MySqlConnector/Core/ICommandPayloadCreator.cs @@ -7,6 +7,16 @@ namespace MySqlConnector.Core; /// internal interface ICommandPayloadCreator { + /// + /// Sends any data that is required to be sent to the server before the current command in the command list. + /// + /// The to which the data will be written. + /// The giving the current command and current prepared statement. + /// The IO behavior. + /// A cancellation token to cancel the asynchronous operation. + /// A representing the asynchronous operation or a completed if no data needed to be sent. + ValueTask SendCommandPrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken); + /// /// Writes the payload for an "execute query" command to . /// diff --git a/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs b/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs index 4cc7b12db..122b4ccc8 100644 --- a/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs +++ b/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs @@ -1,3 +1,6 @@ +using System.Buffers; +using System.Buffers.Binary; +using System.Runtime.InteropServices; using MySqlConnector.Logging; using MySqlConnector.Protocol; using MySqlConnector.Protocol.Serialization; @@ -12,6 +15,81 @@ internal sealed class SingleCommandPayloadCreator : ICommandPayloadCreator // with this as the first column name, the result set will be treated as 'out' parameters for the previous command. public static string OutParameterSentinelColumnName => "\uE001\b\x0B"; + public async ValueTask SendCommandPrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken) + { + // get the current command and check for prepared statements + var command = commandListPosition.CommandAt(commandListPosition.CommandIndex); + var preparedStatements = commandListPosition.PreparedStatements ?? command.TryGetPreparedStatements(); + if (preparedStatements is not null) + { + // get the current prepared statement; WriteQueryCommand will advance this + var preparedStatement = preparedStatements.Statements[commandListPosition.PreparedStatementIndex]; + if (preparedStatement.Parameters is { } parameters) + { + byte[]? buffer = null; + try + { + // check each parameter + for (var i = 0; i < parameters.Length; i++) + { + // look up this parameter in the command's parameter collection and check if it is a Stream + // NOTE: full parameter checks will be performed (and throw any necessary exceptions) in WritePreparedStatement + var parameterName = preparedStatement.Statement.NormalizedParameterNames![i]; + var parameterIndex = parameterName is not null ? (command.RawParameters?.UnsafeIndexOf(parameterName) ?? -1) : preparedStatement.Statement.ParameterIndexes[i]; + if (command.RawParameters is { } rawParameters && parameterIndex >= 0 && parameterIndex < rawParameters.Count && rawParameters[parameterIndex] is { Value: Stream stream and not MemoryStream }) + { + // seven-byte COM_STMT_SEND_LONG_DATA header: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html + const int packetHeaderLength = 7; + + // send almost-full packets, but don't send exactly ProtocolUtility.MaxPacketSize bytes in one payload (as that's ambiguous to whether another packet follows). + const int maxDataSize = 16_000_000; + int totalBytesRead; + while (true) + { + buffer ??= ArrayPool.Shared.Rent(packetHeaderLength + maxDataSize); + buffer[0] = (byte) CommandKind.StatementSendLongData; + BinaryPrimitives.WriteInt32LittleEndian(buffer.AsSpan(1), preparedStatement.StatementId); + BinaryPrimitives.WriteUInt16LittleEndian(buffer.AsSpan(5), (ushort) i); + + // keep reading from the stream until we've filled the buffer to send +#if NET7_0_OR_GREATER + if (ioBehavior == IOBehavior.Synchronous) + totalBytesRead = stream.ReadAtLeast(buffer.AsSpan(packetHeaderLength, maxDataSize), maxDataSize, throwOnEndOfStream: false); + else + totalBytesRead = await stream.ReadAtLeastAsync(buffer.AsMemory(packetHeaderLength, maxDataSize), maxDataSize, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false); +#else + totalBytesRead = 0; + int bytesRead; + do + { + var sizeToRead = maxDataSize - totalBytesRead; + if (ioBehavior == IOBehavior.Synchronous) + bytesRead = stream.Read(buffer, packetHeaderLength + totalBytesRead, sizeToRead); + else + bytesRead = await stream.ReadAsync(buffer, packetHeaderLength + totalBytesRead, sizeToRead, cancellationToken).ConfigureAwait(false); + totalBytesRead += bytesRead; + } while (bytesRead > 0 && totalBytesRead < maxDataSize); +#endif + + if (totalBytesRead == 0) + break; + + // send StatementSendLongData; MySQL Server will keep appending the sent data to the parameter value + using var payload = new PayloadData(buffer.AsMemory(0, packetHeaderLength + totalBytesRead), isPooled: false); + await connection.Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); + } + } + } + } + finally + { + if (buffer is not null) + ArrayPool.Shared.Return(buffer); + } + } + } + } + public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDictionary cachedProcedures, ByteBufferWriter writer, bool appendSemicolon) { if (commandListPosition.CommandIndex == commandListPosition.CommandCount) @@ -58,6 +136,7 @@ public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDict { commandListPosition.CommandIndex++; commandListPosition.PreparedStatementIndex = 0; + commandListPosition.PreparedStatements = null; } } return true; diff --git a/src/MySqlConnector/MySqlDataReader.cs b/src/MySqlConnector/MySqlDataReader.cs index 98e07fdd0..fc34ca365 100644 --- a/src/MySqlConnector/MySqlDataReader.cs +++ b/src/MySqlConnector/MySqlDataReader.cs @@ -74,8 +74,10 @@ internal async Task NextResultAsync(IOBehavior ioBehavior, CancellationTok Command = m_commandListPosition.CommandAt(m_commandListPosition.CommandIndex); using (Command.CancellableCommand.RegisterCancel(cancellationToken)) { + await m_payloadCreator!.SendCommandPrologueAsync(Command.Connection!, m_commandListPosition, ioBehavior, cancellationToken).ConfigureAwait(false); + var writer = new ByteBufferWriter(); - if (!Command.Connection!.Session.IsCancelingQuery && m_payloadCreator!.WriteQueryCommand(ref m_commandListPosition, m_cachedProcedures!, writer, false)) + if (!Command.Connection!.Session.IsCancelingQuery && m_payloadCreator.WriteQueryCommand(ref m_commandListPosition, m_cachedProcedures!, writer, false)) { using var payload = writer.ToPayloadData(); await Command.Connection.Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); diff --git a/src/MySqlConnector/MySqlParameter.cs b/src/MySqlConnector/MySqlParameter.cs index 3913c48ce..b217eb814 100644 --- a/src/MySqlConnector/MySqlParameter.cs +++ b/src/MySqlConnector/MySqlParameter.cs @@ -331,6 +331,10 @@ internal void AppendSqlString(ByteBufferWriter writer, StatementPreparerOptions Debug.Assert(index == length, "index == length"); writer.Advance(index); } + else if (Value is Stream) + { + throw new NotSupportedException($"Parameter type {Value.GetType().Name} can only be used after calling MySqlCommand.Prepare."); + } else if (Value is bool boolValue) { writer.Write(boolValue ? "true"u8 : "false"u8); @@ -728,6 +732,10 @@ private void AppendBinary(ByteBufferWriter writer, object value, StatementPrepar writer.WriteLengthEncodedInteger(unchecked((ulong) streamBuffer.Count)); writer.Write(streamBuffer); } + else if (value is Stream) + { + // do nothing; this will be sent via CommandKind.StatementSendLongData + } else if (value is float floatValue) { #if NET5_0_OR_GREATER diff --git a/src/MySqlConnector/Protocol/CommandKind.cs b/src/MySqlConnector/Protocol/CommandKind.cs index 6e384a733..0d546721c 100644 --- a/src/MySqlConnector/Protocol/CommandKind.cs +++ b/src/MySqlConnector/Protocol/CommandKind.cs @@ -9,5 +9,6 @@ internal enum CommandKind ChangeUser = 17, StatementPrepare = 22, StatementExecute = 23, + StatementSendLongData = 24, ResetConnection = 31, } diff --git a/tests/IntegrationTests/ChunkStream.cs b/tests/IntegrationTests/ChunkStream.cs new file mode 100644 index 000000000..9c44d3ede --- /dev/null +++ b/tests/IntegrationTests/ChunkStream.cs @@ -0,0 +1,130 @@ +namespace IntegrationTests; + +internal sealed class ChunkStream : Stream +{ + public ChunkStream(byte[] data, int chunkLength) + { + if (data is null) + throw new ArgumentNullException(nameof(data)); + if (chunkLength <= 0) + throw new ArgumentOutOfRangeException(nameof(chunkLength)); + + m_data = data; + m_chunkLength = chunkLength; + m_position = 0; + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => m_data.Length; + public override long Position + { + get => m_position; + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (buffer is null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0 || offset > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0 || offset + count > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(count)); + + return Read(buffer.AsSpan(offset, count)); + } + + public +#if NETSTANDARD2_1 || NETCOREAPP2_1_OR_GREATER + override +#endif + int Read(Span buffer) + { + if (m_position >= m_data.Length) + return 0; + + // Read at most chunkLength bytes + var bytesToRead = Math.Min(buffer.Length, Math.Min(m_chunkLength, m_data.Length - m_position)); + + // Copy data from the actual data array + m_data.AsSpan(m_position, bytesToRead).CopyTo(buffer); + + m_position += bytesToRead; + return bytesToRead; + } + + public override int ReadByte() + { + Span buffer = stackalloc byte[1]; + var bytesRead = Read(buffer); + return bytesRead == 0 ? -1 : buffer[0]; + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (buffer is null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0 || offset > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0 || offset + count > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(count)); + + if (cancellationToken.IsCancellationRequested) + return Task.FromCanceled(cancellationToken); + + try + { + return Task.FromResult(Read(buffer.AsSpan(offset, count))); + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + public +#if NETSTANDARD2_1 || NETCOREAPP2_1_OR_GREATER + override +#endif + ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + return new(Task.FromCanceled(cancellationToken)); + + try + { + return new(Read(buffer.Span)); + } + catch (Exception ex) + { + return new(Task.FromException(ex)); + } + } + + public override void Write(byte[] buffer, int offset, int count) => + throw new NotSupportedException(); + + public override void WriteByte(byte value) => + throw new NotSupportedException(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + throw new NotSupportedException(); + + public override void SetLength(long value) => + throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) => + throw new NotSupportedException(); + + public override void Flush() => + throw new NotSupportedException(); + + public override Task FlushAsync(CancellationToken cancellationToken) => + throw new NotSupportedException(); + + private readonly byte[] m_data; + private readonly int m_chunkLength; + private int m_position; +} diff --git a/tests/IntegrationTests/InsertTests.cs b/tests/IntegrationTests/InsertTests.cs index 0557a19e6..63191eacd 100644 --- a/tests/IntegrationTests/InsertTests.cs +++ b/tests/IntegrationTests/InsertTests.cs @@ -475,6 +475,130 @@ public void InsertMySqlDecimalAsDecimal(bool prepare) var val = ((decimal) reader.GetValue(0)).ToString(CultureInfo.InvariantCulture); Assert.Equal(value, val); } + + [Theory] + [InlineData(1_000_000, 1024, true)] + [InlineData(1_000_000, 1024, false)] + [InlineData(1_000_000, int.MaxValue, true)] + [InlineData(1_000_000, int.MaxValue, false)] + [InlineData(0xff_fff8, 299593, true)] + [InlineData(0xff_fff8, 299593, false)] + [InlineData(0xff_fff8, 300000, true)] + [InlineData(0xff_fff8, 300000, false)] + [InlineData(0xff_fff8, int.MaxValue, true)] + [InlineData(0xff_fff8, int.MaxValue, false)] + [InlineData(0xff_fff9, int.MaxValue, true)] + [InlineData(0xff_fff9, int.MaxValue, false)] + [InlineData(0x1ff_fff0, 299593, true)] + [InlineData(0x1ff_fff0, 299593, false)] + [InlineData(0x1ff_fff0, 300000, true)] + [InlineData(0x1ff_fff0, 300000, false)] + [InlineData(15_999_999, int.MaxValue, true)] + [InlineData(15_999_999, int.MaxValue, false)] + [InlineData(16_000_000, int.MaxValue, true)] + [InlineData(16_000_000, int.MaxValue, false)] + [InlineData(16_000_001, int.MaxValue, true)] + [InlineData(16_000_001, int.MaxValue, false)] + [InlineData(31_999_999, 999_999, true)] + [InlineData(31_999_999, 1_000_000, false)] + [InlineData(32_000_000, 1_000_001, true)] + [InlineData(32_000_000, 1_000_002, false)] + [InlineData(32_000_001, 1_000_003, true)] + [InlineData(32_000_001, 1_000_004, false)] + public async Task SendLongData(int dataLength, int chunkLength, bool isAsync) + { + using MySqlConnection connection = new MySqlConnection(AppConfig.ConnectionString); + connection.Open(); + connection.Execute(""" + drop table if exists insert_mysql_long_data; + create table insert_mysql_long_data(rowid integer not null primary key auto_increment, value longblob); + """); + + var random = new Random(dataLength); + var data = new byte[dataLength]; + random.NextBytes(data); + + using var chunkStream = new ChunkStream(data, chunkLength); + + using var writeCommand = new MySqlCommand(""" + insert into insert_mysql_long_data(value) values(@value); + select length(value) from insert_mysql_long_data order by rowid; + """, connection); + writeCommand.Parameters.AddWithValue("@value", chunkStream); + writeCommand.Prepare(); + using (var reader = isAsync ? await writeCommand.ExecuteReaderAsync().ConfigureAwait(true) : writeCommand.ExecuteReader()) + { + Assert.True(reader.Read()); + Assert.Equal(1, reader.FieldCount); + Assert.Equal(dataLength, reader.GetInt32(0)); + Assert.False(reader.Read()); + } + + using var readCommand = new MySqlCommand("select value from insert_mysql_long_data order by rowid;", connection); + using (var reader = readCommand.ExecuteReader()) + { + Assert.True(reader.Read()); + var readData = (byte[]) reader.GetValue(0); + Assert.True(data.AsSpan().SequenceEqual(readData)); // much faster than Assert.Equal + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendLongDataMultipleStatements(bool isAsync) + { + using MySqlConnection connection = new MySqlConnection(AppConfig.ConnectionString); + connection.Open(); + connection.Execute(""" + drop table if exists insert_mysql_long_data; + create table insert_mysql_long_data(rowid integer not null primary key auto_increment, value longblob); + """); + + var data1 = new byte[1000]; + var data2 = new byte[2000]; + var data3 = new byte[3000]; + var random = new Random(1); + random.NextBytes(data1); + random.NextBytes(data2); + random.NextBytes(data3); + + using var chunkStream1 = new ChunkStream(data1, int.MaxValue); + using var chunkStream2 = new ChunkStream(data2, int.MaxValue); + using var chunkStream3 = new ChunkStream(data3, int.MaxValue); + + using var writeCommand = new MySqlCommand(""" + insert into insert_mysql_long_data(rowid, value) values(1, @value1); + insert into insert_mysql_long_data(rowid, value) values(2, @value2); + insert into insert_mysql_long_data(rowid, value) values(3, @value3); + """, connection); + writeCommand.Parameters.AddWithValue("@value1", chunkStream1); + writeCommand.Parameters.AddWithValue("@value2", chunkStream2); + writeCommand.Parameters.AddWithValue("@value3", chunkStream3); + writeCommand.Prepare(); + if (isAsync) + await writeCommand.ExecuteNonQueryAsync(); + else + writeCommand.ExecuteNonQuery(); + + using var readCommand = new MySqlCommand("select value from insert_mysql_long_data order by rowid;", connection); + using (var reader = readCommand.ExecuteReader()) + { + Assert.True(reader.Read()); + var readData = (byte[]) reader.GetValue(0); + Assert.True(data1.AsSpan().SequenceEqual(readData)); + + Assert.True(reader.Read()); + readData = (byte[]) reader.GetValue(0); + Assert.True(data2.AsSpan().SequenceEqual(readData)); + + Assert.True(reader.Read()); + readData = (byte[]) reader.GetValue(0); + Assert.True(data3.AsSpan().SequenceEqual(readData)); + + Assert.False(reader.Read()); + } + } #endif [Theory]