Skip to content

Commit 6963457

Browse files
committed
Use MySqlBulkCopy internally as the source.
This will allow more properties and events to be added to it more easily.
1 parent bb09ccf commit 6963457

File tree

2 files changed

+59
-22
lines changed

2 files changed

+59
-22
lines changed

src/MySqlConnector/Core/ResultSet.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior)
9999
}
100100
break;
101101

102-
case IValuesEnumerator valuesEnumerator:
103-
await MySqlBulkCopy.SendDataReaderAsync(Connection, valuesEnumerator, ioBehavior, CancellationToken.None).ConfigureAwait(false);
102+
case MySqlBulkCopy bulkCopy:
103+
await bulkCopy.SendDataReaderAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false);
104104
break;
105105

106106
default:

src/MySqlConnector/MySql.Data.MySqlClient/MySqlBulkCopy.cs

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,69 @@ public MySqlBulkCopy(MySqlConnection connection, MySqlTransaction? transaction =
2727
public string? DestinationTableName { get; set; }
2828

2929
#if !NETSTANDARD1_3
30-
public void WriteToServer(DataTable dataTable) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
30+
public void WriteToServer(DataTable dataTable)
31+
{
32+
m_valuesEnumerator = DataRowsValuesEnumerator.Create(dataTable ?? throw new ArgumentNullException(nameof(dataTable)));
33+
WriteToServerAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
34+
}
3135

3236
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
33-
public Task WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None).AsTask();
37+
public async Task WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default)
38+
{
39+
m_valuesEnumerator = DataRowsValuesEnumerator.Create(dataTable ?? throw new ArgumentNullException(nameof(dataTable)));
40+
await WriteToServerAsync(IOBehavior.Asynchronous, CancellationToken.None).ConfigureAwait(false);
41+
}
3442
#else
35-
public ValueTask WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None);
43+
public async ValueTask WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default)
44+
{
45+
m_valuesEnumerator = DataRowsValuesEnumerator.Create(dataTable ?? throw new ArgumentNullException(nameof(dataTable)));
46+
await WriteToServerAsync(IOBehavior.Asynchronous, cancellationToken).ConfigureAwait(false);
47+
}
3648
#endif
3749

38-
public void WriteToServer(IEnumerable<DataRow> dataRows, int columnCount) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
50+
public void WriteToServer(IEnumerable<DataRow> dataRows, int columnCount)
51+
{
52+
m_valuesEnumerator = new DataRowsValuesEnumerator(dataRows ?? throw new ArgumentNullException(nameof(dataRows)), columnCount);
53+
WriteToServerAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
54+
}
3955
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
40-
public Task WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Asynchronous, cancellationToken).AsTask();
56+
public async Task WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default)
57+
{
58+
m_valuesEnumerator = new DataRowsValuesEnumerator(dataRows ?? throw new ArgumentNullException(nameof(dataRows)), columnCount);
59+
await WriteToServerAsync(IOBehavior.Asynchronous, cancellationToken).ConfigureAwait(false);
60+
}
4161
#else
42-
public ValueTask WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Asynchronous, cancellationToken);
62+
public async ValueTask WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default)
63+
{
64+
m_valuesEnumerator = new DataRowsValuesEnumerator(dataRows ?? throw new ArgumentNullException(nameof(dataRows)), columnCount);
65+
await WriteToServerAsync(IOBehavior.Asynchronous, cancellationToken).ConfigureAwait(false);
66+
}
4367
#endif
4468
#endif
4569

46-
public void WriteToServer(IDataReader dataReader) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
70+
public void WriteToServer(IDataReader dataReader)
71+
{
72+
m_valuesEnumerator = DataReaderValuesEnumerator.Create(dataReader ?? throw new ArgumentNullException(nameof(dataReader)));
73+
WriteToServerAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
74+
}
4775
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
48-
public Task WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Asynchronous, cancellationToken).AsTask();
76+
public async Task WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default)
77+
{
78+
m_valuesEnumerator = DataReaderValuesEnumerator.Create(dataReader ?? throw new ArgumentNullException(nameof(dataReader)));
79+
await WriteToServerAsync(IOBehavior.Asynchronous, cancellationToken).ConfigureAwait(false);
80+
}
4981
#else
50-
public ValueTask WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Asynchronous, cancellationToken);
82+
public async ValueTask WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default)
83+
{
84+
m_valuesEnumerator = DataReaderValuesEnumerator.Create(dataReader ?? throw new ArgumentNullException(nameof(dataReader)));
85+
await WriteToServerAsync(IOBehavior.Asynchronous, cancellationToken).ConfigureAwait(false);
86+
}
5187
#endif
5288

5389
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
54-
private async ValueTask<int> WriteToServerAsync(IValuesEnumerator values, IOBehavior ioBehavior, CancellationToken cancellationToken)
90+
private async ValueTask<int> WriteToServerAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
5591
#else
56-
private async ValueTask WriteToServerAsync(IValuesEnumerator values, IOBehavior ioBehavior, CancellationToken cancellationToken)
92+
private async ValueTask WriteToServerAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
5793
#endif
5894
{
5995
var tableName = DestinationTableName ?? throw new InvalidOperationException("DestinationTableName must be set before calling WriteToServer");
@@ -68,7 +104,7 @@ private async ValueTask WriteToServerAsync(IValuesEnumerator values, IOBehavior
68104
LineTerminator = "\n",
69105
Local = true,
70106
NumberOfLinesToSkip = 0,
71-
Source = values ?? throw new ArgumentNullException(nameof(values)),
107+
Source = this,
72108
TableName = tableName,
73109
Timeout = BulkCopyTimeout,
74110
};
@@ -124,7 +160,7 @@ private async ValueTask WriteToServerAsync(IValuesEnumerator values, IOBehavior
124160
static string QuoteIdentifier(string identifier) => "`" + identifier.Replace("`", "``") + "`";
125161
}
126162

127-
internal static async Task SendDataReaderAsync(MySqlConnection connection, IValuesEnumerator valuesEnumerator, IOBehavior ioBehavior, CancellationToken cancellationToken)
163+
internal async Task SendDataReaderAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
128164
{
129165
// rent a buffer that can fit in one packet
130166
const int maxLength = 16_777_200;
@@ -133,16 +169,16 @@ internal static async Task SendDataReaderAsync(MySqlConnection connection, IValu
133169

134170
try
135171
{
136-
var values = new object?[valuesEnumerator.FieldCount];
172+
var values = new object?[m_valuesEnumerator!.FieldCount];
137173
while (true)
138174
{
139175
var hasMore = ioBehavior == IOBehavior.Asynchronous ?
140-
await valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
141-
valuesEnumerator.MoveNext();
176+
await m_valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
177+
m_valuesEnumerator.MoveNext();
142178
if (!hasMore)
143179
break;
144180

145-
valuesEnumerator.GetValues(values);
181+
m_valuesEnumerator.GetValues(values);
146182
retryRow:
147183
var startOutputIndex = outputIndex;
148184
var wroteRow = true;
@@ -154,7 +190,7 @@ await valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
154190
else
155191
shouldAppendSeparator = true;
156192

157-
if (outputIndex >= maxLength || !WriteValue(connection, value, buffer.AsSpan(0, maxLength).Slice(outputIndex), out var bytesWritten))
193+
if (outputIndex >= maxLength || !WriteValue(m_connection, value, buffer.AsSpan(0, maxLength).Slice(outputIndex), out var bytesWritten))
158194
{
159195
wroteRow = false;
160196
break;
@@ -167,7 +203,7 @@ await valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
167203
if (startOutputIndex == 0)
168204
throw new NotSupportedException("Total row length must be less than 16MiB.");
169205
var payload = new PayloadData(new ArraySegment<byte>(buffer, 0, startOutputIndex));
170-
await connection.Session.SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
206+
await m_connection.Session.SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
171207
outputIndex = 0;
172208
goto retryRow;
173209
}
@@ -180,7 +216,7 @@ await valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
180216
if (outputIndex != 0)
181217
{
182218
var payload2 = new PayloadData(new ArraySegment<byte>(buffer, 0, outputIndex));
183-
await connection.Session.SendReplyAsync(payload2, ioBehavior, cancellationToken).ConfigureAwait(false);
219+
await m_connection.Session.SendReplyAsync(payload2, ioBehavior, cancellationToken).ConfigureAwait(false);
184220
}
185221
}
186222
finally
@@ -412,5 +448,6 @@ static bool WriteBytes(ReadOnlySpan<byte> value, Span<byte> output, out int byte
412448

413449
readonly MySqlConnection m_connection;
414450
readonly MySqlTransaction? m_transaction;
451+
IValuesEnumerator? m_valuesEnumerator;
415452
}
416453
}

0 commit comments

Comments
 (0)