Skip to content

Commit 4638f17

Browse files
committed
Support multiple prepared statements in one command.
Signed-off-by: Bradley Grainger <[email protected]>
1 parent 8894090 commit 4638f17

File tree

6 files changed

+97
-66
lines changed

6 files changed

+97
-66
lines changed

src/MySqlConnector/Core/CommandExecutor.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ public static async ValueTask<MySqlDataReader> ExecuteReaderAsync(CommandListPos
3939
}
4040
}
4141

42+
await payloadCreator.WritePrologueAsync(connection, commandListPosition, ioBehavior, cancellationToken).ConfigureAwait(false);
43+
4244
var writer = new ByteBufferWriter();
4345
//// cachedProcedures will be non-null if there is a stored procedure, which is also the only time it will be read
4446
if (!payloadCreator.WriteQueryCommand(ref commandListPosition, cachedProcedures!, writer, false))

src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ internal sealed class ConcatenatedCommandPayloadCreator : ICommandPayloadCreator
88
{
99
public static ICommandPayloadCreator Instance { get; } = new ConcatenatedCommandPayloadCreator();
1010

11+
public ValueTask WritePrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken) =>
12+
throw new NotSupportedException();
13+
1114
public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDictionary<string, CachedProcedure?> cachedProcedures, ByteBufferWriter writer, bool appendSemicolon)
1215
{
1316
if (commandListPosition.CommandIndex == commandListPosition.CommandCount)

src/MySqlConnector/Core/ICommandPayloadCreator.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ namespace MySqlConnector.Core;
77
/// </summary>
88
internal interface ICommandPayloadCreator
99
{
10+
/// <summary>
11+
/// Writes any prologue data that needs to be sent before the current command in the command list.
12+
/// </summary>
13+
/// <param name="connection">The <see cref="MySqlConnection"/>.</param>
14+
/// <param name="commandListPosition">The <see cref="CommandListPosition"/> giving the current command and current prepared statement.</param>
15+
/// <param name="ioBehavior">The IO behavior.</param>
16+
/// <param name="cancellationToken">A cancellation token to cancel the asynchronous operation.</param>
17+
/// <returns>A <see cref="ValueTask"/> representing the potentially-asynchronous operation.</returns>
18+
ValueTask WritePrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken);
19+
1020
/// <summary>
1121
/// Writes the payload for an "execute query" command to <paramref name="writer"/>.
1222
/// </summary>

src/MySqlConnector/Core/SingleCommandPayloadCreator.cs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Runtime.InteropServices;
12
using MySqlConnector.Logging;
23
using MySqlConnector.Protocol;
34
using MySqlConnector.Protocol.Serialization;
@@ -12,6 +13,74 @@ internal sealed class SingleCommandPayloadCreator : ICommandPayloadCreator
1213
// with this as the first column name, the result set will be treated as 'out' parameters for the previous command.
1314
public static string OutParameterSentinelColumnName => "\uE001\b\x0B";
1415

16+
public async ValueTask WritePrologueAsync(MySqlConnection connection, CommandListPosition commandListPosition, IOBehavior ioBehavior, CancellationToken cancellationToken)
17+
{
18+
// get the current command and check for prepared statements
19+
var command = commandListPosition.CommandAt(commandListPosition.CommandIndex);
20+
var preparedStatements = commandListPosition.PreparedStatements ?? command.TryGetPreparedStatements();
21+
if (preparedStatements is not null)
22+
{
23+
// get the current prepared statement; WriteQueryCommand will advance this
24+
var preparedStatement = preparedStatements.Statements[commandListPosition.PreparedStatementIndex];
25+
if (preparedStatement.Parameters is { } parameters)
26+
{
27+
// check each parameter
28+
for (var i = 0; i < parameters.Length; i++)
29+
{
30+
// look up this parameter in the command's parameter collection and check if it is a Stream
31+
var parameterName = preparedStatement.Statement.NormalizedParameterNames![i];
32+
var parameterIndex = parameterName is not null ? (command.RawParameters?.UnsafeIndexOf(parameterName) ?? -1) : preparedStatement.Statement.ParameterIndexes[i];
33+
if (parameterIndex != -1 && command.RawParameters![parameterIndex] is { Value: Stream stream and not MemoryStream })
34+
{
35+
// send almost-full packets, but don't send exactly ProtocolUtility.MaxPacketSize bytes in one payload (as that's ambiguous to whether another packet follows).
36+
const int maxDataSize = 16_000_000;
37+
int totalBytesRead;
38+
while (true)
39+
{
40+
// write seven-byte COM_STMT_SEND_LONG_DATA header: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html
41+
var writer = new ByteBufferWriter(maxDataSize);
42+
writer.Write((byte) CommandKind.StatementSendLongData);
43+
writer.Write(preparedStatement.StatementId);
44+
writer.Write((ushort) i);
45+
46+
// keep reading from the stream until we've filled the buffer to send
47+
#if NET7_0_OR_GREATER
48+
if (ioBehavior == IOBehavior.Synchronous)
49+
totalBytesRead = stream.ReadAtLeast(writer.GetSpan(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false);
50+
else
51+
totalBytesRead = await stream.ReadAtLeastAsync(writer.GetMemory(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
52+
writer.Advance(totalBytesRead);
53+
#else
54+
totalBytesRead = 0;
55+
int bytesRead;
56+
do
57+
{
58+
var sizeToRead = maxDataSize - totalBytesRead;
59+
ReadOnlyMemory<byte> bufferMemory = writer.GetMemory(sizeToRead);
60+
if (!MemoryMarshal.TryGetArray(bufferMemory, out var arraySegment))
61+
throw new InvalidOperationException("Failed to get array segment from buffer memory.");
62+
if (ioBehavior == IOBehavior.Synchronous)
63+
bytesRead = stream.Read(arraySegment.Array!, arraySegment.Offset, sizeToRead);
64+
else
65+
bytesRead = await stream.ReadAsync(arraySegment.Array!, arraySegment.Offset, sizeToRead, cancellationToken).ConfigureAwait(false);
66+
totalBytesRead += bytesRead;
67+
writer.Advance(bytesRead);
68+
} while (bytesRead > 0);
69+
#endif
70+
71+
if (totalBytesRead == 0)
72+
break;
73+
74+
// send StatementSendLongData; MySQL Server will keep appending the sent data to the parameter value
75+
using var payload = writer.ToPayloadData();
76+
await connection.Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
77+
}
78+
}
79+
}
80+
}
81+
}
82+
}
83+
1584
public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDictionary<string, CachedProcedure?> cachedProcedures, ByteBufferWriter writer, bool appendSemicolon)
1685
{
1786
if (commandListPosition.CommandIndex == commandListPosition.CommandCount)
@@ -58,6 +127,7 @@ public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDict
58127
{
59128
commandListPosition.CommandIndex++;
60129
commandListPosition.PreparedStatementIndex = 0;
130+
commandListPosition.PreparedStatements = null;
61131
}
62132
}
63133
return true;

src/MySqlConnector/MySqlCommand.cs

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -359,68 +359,15 @@ internal async ValueTask<MySqlDataReader> ExecuteReaderAsync(CommandBehavior beh
359359
return await ExecuteReaderNoResetTimeoutAsync(behavior, ioBehavior, cancellationToken).ConfigureAwait(false);
360360
}
361361

362-
internal async ValueTask<MySqlDataReader> ExecuteReaderNoResetTimeoutAsync(CommandBehavior behavior, IOBehavior ioBehavior, CancellationToken cancellationToken)
362+
internal ValueTask<MySqlDataReader> ExecuteReaderNoResetTimeoutAsync(CommandBehavior behavior, IOBehavior ioBehavior, CancellationToken cancellationToken)
363363
{
364364
if (!IsValid(out var exception))
365-
throw exception;
366-
367-
if (((IMySqlCommand) this).TryGetPreparedStatements() is { Statements.Count: 1 } statements)
368-
{
369-
for (var i = 0; i < Parameters.Count; i++)
370-
{
371-
if (Parameters[i].Value is Stream stream and not MemoryStream)
372-
{
373-
// send almost-full packets, but don't send exactly ProtocolUtility.MaxPacketSize bytes in one payload (as that's ambiguous to whether another packet follows).
374-
const int maxDataSize = 16_000_000;
375-
int totalBytesRead;
376-
while (true)
377-
{
378-
// write seven-byte COM_STMT_SEND_LONG_DATA header: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html
379-
var writer = new ByteBufferWriter(maxDataSize);
380-
writer.Write((byte) CommandKind.StatementSendLongData);
381-
writer.Write(statements.Statements[0].StatementId);
382-
writer.Write((ushort) i);
383-
384-
// keep reading from the stream until we've filled the buffer to send
385-
#if NET7_0_OR_GREATER
386-
if (ioBehavior == IOBehavior.Synchronous)
387-
totalBytesRead = stream.ReadAtLeast(writer.GetSpan(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false);
388-
else
389-
totalBytesRead = await stream.ReadAtLeastAsync(writer.GetMemory(maxDataSize).Slice(0, maxDataSize), maxDataSize, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
390-
writer.Advance(totalBytesRead);
391-
#else
392-
totalBytesRead = 0;
393-
int bytesRead;
394-
do
395-
{
396-
var sizeToRead = maxDataSize - totalBytesRead;
397-
ReadOnlyMemory<byte> bufferMemory = writer.GetMemory(sizeToRead);
398-
if (!MemoryMarshal.TryGetArray(bufferMemory, out var arraySegment))
399-
throw new InvalidOperationException("Failed to get array segment from buffer memory.");
400-
if (ioBehavior == IOBehavior.Synchronous)
401-
bytesRead = stream.Read(arraySegment.Array!, arraySegment.Offset, sizeToRead);
402-
else
403-
bytesRead = await stream.ReadAsync(arraySegment.Array!, arraySegment.Offset, sizeToRead, cancellationToken).ConfigureAwait(false);
404-
totalBytesRead += bytesRead;
405-
writer.Advance(bytesRead);
406-
} while (bytesRead > 0);
407-
#endif
408-
409-
if (totalBytesRead == 0)
410-
break;
411-
412-
// send StatementSendLongData; MySQL Server will keep appending the sent data to the parameter value
413-
using var payload = writer.ToPayloadData();
414-
await Connection!.Session.SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
415-
}
416-
}
417-
}
418-
}
365+
return ValueTaskExtensions.FromException<MySqlDataReader>(exception);
419366

420367
var activity = NoActivity ? null : Connection!.Session.StartActivity(ActivitySourceHelper.ExecuteActivityName,
421368
ActivitySourceHelper.DatabaseStatementTagName, CommandText);
422369
m_commandBehavior = behavior;
423-
return await CommandExecutor.ExecuteReaderAsync(new(this), SingleCommandPayloadCreator.Instance, behavior, activity, ioBehavior, cancellationToken).ConfigureAwait(false);
370+
return CommandExecutor.ExecuteReaderAsync(new(this), SingleCommandPayloadCreator.Instance, behavior, activity, ioBehavior, cancellationToken);
424371
}
425372

426373
public MySqlCommand Clone() => new(this);

tests/IntegrationTests/InsertTests.cs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -520,22 +520,21 @@ public async Task SendLongData(int dataLength, int chunkLength, bool isAsync)
520520

521521
using var chunkStream = new ChunkStream(data, chunkLength);
522522

523-
using var writeCommand = new MySqlCommand("insert into insert_mysql_long_data(value) values(@value);", connection);
523+
using var writeCommand = new MySqlCommand("""
524+
insert into insert_mysql_long_data(value) values(@value);
525+
select length(value) from insert_mysql_long_data order by rowid;
526+
""", connection);
524527
writeCommand.Parameters.AddWithValue("@value", chunkStream);
525528
writeCommand.Prepare();
526-
if (isAsync)
527-
await writeCommand.ExecuteNonQueryAsync().ConfigureAwait(true);
528-
else
529-
writeCommand.ExecuteNonQuery();
530-
531-
using var readCommand = new MySqlCommand("select length(value) from insert_mysql_long_data order by rowid;", connection);
532-
using (var reader = readCommand.ExecuteReader())
529+
using (var reader = isAsync ? await writeCommand.ExecuteReaderAsync().ConfigureAwait(true) : writeCommand.ExecuteReader())
533530
{
534531
Assert.True(reader.Read());
535-
Assert.Equal(chunkStream.Length, reader.GetInt32(0));
532+
Assert.Equal(1, reader.FieldCount);
533+
Assert.Equal(dataLength, reader.GetInt32(0));
534+
Assert.False(reader.Read());
536535
}
537536

538-
readCommand.CommandText = "select value from insert_mysql_long_data order by rowid;";
537+
using var readCommand = new MySqlCommand("select value from insert_mysql_long_data order by rowid;", connection);
539538
using (var reader = readCommand.ExecuteReader())
540539
{
541540
Assert.True(reader.Read());

0 commit comments

Comments
 (0)