Skip to content

Commit 52bc3b2

Browse files
committed
Add IValuesEnumerator abstraction.
This allows MySqlBulkCopy.WriteToServer to operate on IEnumerable<DataRow>.
1 parent 8e9ea27 commit 52bc3b2

File tree

5 files changed

+180
-28
lines changed

5 files changed

+180
-28
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using System.Collections.Generic;
2+
using System.Data;
3+
using System.Data.Common;
4+
using System.Linq;
5+
using System.Threading.Tasks;
6+
7+
namespace MySqlConnector.Core
8+
{
9+
/// <summary>
10+
/// <see cref="IValuesEnumerator"/> provides an abstraction over iterating through a sequence of
11+
/// rows, where each row can fill an array of field values.
12+
/// </summary>
13+
internal interface IValuesEnumerator
14+
{
15+
int FieldCount { get; }
16+
ValueTask<bool> MoveNextAsync();
17+
bool MoveNext();
18+
void GetValues(object?[] values);
19+
}
20+
21+
internal sealed class DbDataReaderValuesEnumerator : IValuesEnumerator
22+
{
23+
public DbDataReaderValuesEnumerator(DbDataReader dataReader) => m_dataReader = dataReader;
24+
25+
public int FieldCount => m_dataReader.FieldCount;
26+
27+
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(m_dataReader.ReadAsync());
28+
29+
public bool MoveNext() => m_dataReader.Read();
30+
31+
public void GetValues(object?[] values) => m_dataReader.GetValues(values);
32+
33+
readonly DbDataReader m_dataReader;
34+
}
35+
36+
internal sealed class DataReaderValuesEnumerator : IValuesEnumerator
37+
{
38+
public static IValuesEnumerator Create(IDataReader dataReader) => dataReader is DbDataReader dbDataReader ? (IValuesEnumerator) new DbDataReaderValuesEnumerator(dbDataReader) : new DataReaderValuesEnumerator(dataReader);
39+
40+
public DataReaderValuesEnumerator(IDataReader dataReader) => m_dataReader = dataReader;
41+
42+
public int FieldCount => m_dataReader.FieldCount;
43+
44+
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(MoveNext());
45+
46+
public bool MoveNext() => m_dataReader.Read();
47+
48+
public void GetValues(object?[] values) => m_dataReader.GetValues(values);
49+
50+
readonly IDataReader m_dataReader;
51+
}
52+
53+
#if !NETSTANDARD1_3
54+
internal sealed class DataRowsValuesEnumerator : IValuesEnumerator
55+
{
56+
public static IValuesEnumerator Create(DataTable dataTable) => new DataRowsValuesEnumerator(dataTable.Rows.Cast<DataRow>(), dataTable.Columns.Count);
57+
58+
public DataRowsValuesEnumerator(IEnumerable<DataRow> dataRows, int columnCount)
59+
{
60+
m_dataRows = dataRows.GetEnumerator();
61+
FieldCount = columnCount;
62+
}
63+
64+
public int FieldCount { get; }
65+
66+
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(MoveNext());
67+
68+
public bool MoveNext()
69+
{
70+
if (m_dataRows.MoveNext())
71+
return true;
72+
m_dataRows.Dispose();
73+
return false;
74+
}
75+
76+
public void GetValues(object?[] values)
77+
{
78+
var row = m_dataRows.Current;
79+
for (var i = 0; i < FieldCount; i++)
80+
values[i] = row[i];
81+
}
82+
83+
readonly IEnumerator<DataRow> m_dataRows;
84+
}
85+
#endif
86+
}

src/MySqlConnector/Core/ResultSet.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior)
9999
}
100100
break;
101101

102-
case IDataReader dataReader:
103-
await MySqlBulkCopy.SendDataReaderAsync(Connection, dataReader, ioBehavior, CancellationToken.None).ConfigureAwait(false);
102+
case IValuesEnumerator valuesEnumerator:
103+
await MySqlBulkCopy.SendDataReaderAsync(Connection, valuesEnumerator, ioBehavior, CancellationToken.None).ConfigureAwait(false);
104104
break;
105105

106106
default:

src/MySqlConnector/MySql.Data.MySqlClient/MySqlBulkCopy.cs

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using System;
22
using System.Buffers;
33
using System.Buffers.Text;
4+
using System.Collections.Generic;
45
using System.Data;
5-
using System.Data.Common;
66
using System.Text;
77
using System.Threading;
88
using System.Threading.Tasks;
99
using MySql.Data.Types;
10+
using MySqlConnector.Core;
1011
using MySqlConnector.Protocol;
1112
using MySqlConnector.Protocol.Serialization;
1213
using MySqlConnector.Utilities;
@@ -26,39 +27,33 @@ public MySqlBulkCopy(MySqlConnection connection, MySqlTransaction? transaction =
2627
public string? DestinationTableName { get; set; }
2728

2829
#if !NETSTANDARD1_3
29-
public void WriteToServer(DataTable dataTable) => WriteToServerAsync(dataTable, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
30+
public void WriteToServer(DataTable dataTable) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
3031

3132
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
32-
public Task WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(dataTable, IOBehavior.Asynchronous, cancellationToken).AsTask();
33+
public Task WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None).AsTask();
3334
#else
34-
public ValueTask WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(dataTable, IOBehavior.Asynchronous, cancellationToken);
35+
public ValueTask WriteToServerAsync(DataTable dataTable, CancellationToken cancellationToken = default) => WriteToServerAsync(DataRowsValuesEnumerator.Create(dataTable), IOBehavior.Synchronous, CancellationToken.None);
3536
#endif
3637

38+
public void WriteToServer(IEnumerable<DataRow> dataRows, int columnCount) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
3739
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
38-
private async ValueTask<int> WriteToServerAsync(DataTable dataTable, IOBehavior ioBehavior, CancellationToken cancellationToken)
40+
public Task WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Asynchronous, cancellationToken).AsTask();
3941
#else
40-
private async ValueTask WriteToServerAsync(DataTable dataTable, IOBehavior ioBehavior, CancellationToken cancellationToken)
42+
public ValueTask WriteToServerAsync(IEnumerable<DataRow> dataRows, int columnCount, CancellationToken cancellationToken = default) => WriteToServerAsync(new DataRowsValuesEnumerator(dataRows, columnCount), IOBehavior.Asynchronous, cancellationToken);
4143
#endif
42-
{
43-
using var reader = dataTable.CreateDataReader();
44-
await WriteToServerAsync(reader, ioBehavior, cancellationToken);
45-
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
46-
return 0;
47-
#endif
48-
}
4944
#endif
5045

51-
public void WriteToServer(IDataReader dataReader) => WriteToServerAsync(dataReader, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
46+
public void WriteToServer(IDataReader dataReader) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult();
5247
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
53-
public Task WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(dataReader, IOBehavior.Asynchronous, cancellationToken).AsTask();
48+
public Task WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Asynchronous, cancellationToken).AsTask();
5449
#else
55-
public ValueTask WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(dataReader, IOBehavior.Asynchronous, cancellationToken);
50+
public ValueTask WriteToServerAsync(IDataReader dataReader, CancellationToken cancellationToken = default) => WriteToServerAsync(DataReaderValuesEnumerator.Create(dataReader), IOBehavior.Asynchronous, cancellationToken);
5651
#endif
5752

5853
#if !NETSTANDARD2_1 && !NETCOREAPP3_0
59-
private async ValueTask<int> WriteToServerAsync(IDataReader dataReader, IOBehavior ioBehavior, CancellationToken cancellationToken)
54+
private async ValueTask<int> WriteToServerAsync(IValuesEnumerator values, IOBehavior ioBehavior, CancellationToken cancellationToken)
6055
#else
61-
private async ValueTask WriteToServerAsync(IDataReader dataReader, IOBehavior ioBehavior, CancellationToken cancellationToken)
56+
private async ValueTask WriteToServerAsync(IValuesEnumerator values, IOBehavior ioBehavior, CancellationToken cancellationToken)
6257
#endif
6358
{
6459
var tableName = DestinationTableName ?? throw new InvalidOperationException("DestinationTableName must be set before calling WriteToServer");
@@ -73,7 +68,7 @@ private async ValueTask WriteToServerAsync(IDataReader dataReader, IOBehavior io
7368
LineTerminator = "\n",
7469
Local = true,
7570
NumberOfLinesToSkip = 0,
76-
Source = dataReader ?? throw new ArgumentNullException(nameof(dataReader)),
71+
Source = values ?? throw new ArgumentNullException(nameof(values)),
7772
TableName = tableName,
7873
Timeout = BulkCopyTimeout,
7974
};
@@ -129,26 +124,25 @@ private async ValueTask WriteToServerAsync(IDataReader dataReader, IOBehavior io
129124
static string QuoteIdentifier(string identifier) => "`" + identifier.Replace("`", "``") + "`";
130125
}
131126

132-
internal static async Task SendDataReaderAsync(MySqlConnection connection, IDataReader dataReader, IOBehavior ioBehavior, CancellationToken cancellationToken)
127+
internal static async Task SendDataReaderAsync(MySqlConnection connection, IValuesEnumerator valuesEnumerator, IOBehavior ioBehavior, CancellationToken cancellationToken)
133128
{
134129
// rent a buffer that can fit in one packet
135130
const int maxLength = 16_777_200;
136131
var buffer = ArrayPool<byte>.Shared.Rent(maxLength + 1);
137132
var outputIndex = 0;
138-
var dbDataReader = dataReader as DbDataReader;
139133

140134
try
141135
{
142-
var values = new object?[dataReader.FieldCount];
136+
var values = new object?[valuesEnumerator.FieldCount];
143137
while (true)
144138
{
145-
var hasMore = ioBehavior == IOBehavior.Asynchronous && dbDataReader is object ?
146-
await dbDataReader.ReadAsync(cancellationToken).ConfigureAwait(false) :
147-
dataReader.Read();
139+
var hasMore = ioBehavior == IOBehavior.Asynchronous ?
140+
await valuesEnumerator.MoveNextAsync().ConfigureAwait(false) :
141+
valuesEnumerator.MoveNext();
148142
if (!hasMore)
149143
break;
150144

151-
dataReader.GetValues(values);
145+
valuesEnumerator.GetValues(values);
152146
retryRow:
153147
var startOutputIndex = outputIndex;
154148
var wroteRow = true;

tests/SideBySide/BulkLoaderAsync.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,42 @@ public async Task BulkLoadLocalMemoryStream()
410410
Assert.Equal(5, rowCount);
411411
}
412412

413+
[Fact]
414+
public async Task BulkLoadDataReader()
415+
{
416+
using var connection = new MySqlConnection(GetLocalConnectionString());
417+
using var connection2 = new MySqlConnection(GetLocalConnectionString());
418+
await connection.OpenAsync();
419+
await connection2.OpenAsync();
420+
using (var cmd = new MySqlCommand(@"drop table if exists bulk_load_data_reader_source;
421+
drop table if exists bulk_load_data_reader_destination;
422+
create table bulk_load_data_reader_source(value int, name text);
423+
create table bulk_load_data_reader_destination(value int, name text);
424+
insert into bulk_load_data_reader_source values(0, 'zero'),(1,'one'),(2,'two'),(3,'three'),(4,'four'),(5,'five'),(6,'six');", connection))
425+
{
426+
await cmd.ExecuteNonQueryAsync();
427+
}
428+
429+
using (var cmd = new MySqlCommand("select * from bulk_load_data_reader_source;", connection))
430+
using (var reader = await cmd.ExecuteReaderAsync())
431+
{
432+
var bulkCopy = new MySqlBulkCopy(connection2) { DestinationTableName = "bulk_load_data_reader_destination", };
433+
await bulkCopy.WriteToServerAsync(reader);
434+
}
435+
436+
using var cmd1 = new MySqlCommand("select * from bulk_load_data_reader_source order by value;", connection);
437+
using var cmd2 = new MySqlCommand("select * from bulk_load_data_reader_destination order by value;", connection2);
438+
using var reader1 = await cmd1.ExecuteReaderAsync();
439+
using var reader2 = await cmd2.ExecuteReaderAsync();
440+
while (await reader1.ReadAsync())
441+
{
442+
Assert.True(await reader2.ReadAsync());
443+
Assert.Equal(reader1.GetInt32(0), reader2.GetInt32(0));
444+
Assert.Equal(reader1.GetString(1), reader2.GetString(1));
445+
}
446+
Assert.False(await reader2.ReadAsync());
447+
}
448+
413449
#if !NETCOREAPP1_1_2
414450
[Fact]
415451
public async Task BulkLoadDataTableWithLongData()

tests/SideBySide/BulkLoaderSync.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,42 @@ public void BulkLoadLocalMemoryStream()
492492
Assert.Equal(5, rowCount);
493493
}
494494

495+
[Fact]
496+
public void BulkLoadDataReader()
497+
{
498+
using var connection = new MySqlConnection(GetLocalConnectionString());
499+
using var connection2 = new MySqlConnection(GetLocalConnectionString());
500+
connection.Open();
501+
connection2.Open();
502+
using (var cmd = new MySqlCommand(@"drop table if exists bulk_load_data_reader_source;
503+
drop table if exists bulk_load_data_reader_destination;
504+
create table bulk_load_data_reader_source(value int, name text);
505+
create table bulk_load_data_reader_destination(value int, name text);
506+
insert into bulk_load_data_reader_source values(0, 'zero'),(1,'one'),(2,'two'),(3,'three'),(4,'four'),(5,'five'),(6,'six');", connection))
507+
{
508+
cmd.ExecuteNonQuery();
509+
}
510+
511+
using (var cmd = new MySqlCommand("select * from bulk_load_data_reader_source;", connection))
512+
using (var reader = cmd.ExecuteReader())
513+
{
514+
var bulkCopy = new MySqlBulkCopy(connection2) { DestinationTableName = "bulk_load_data_reader_destination", };
515+
bulkCopy.WriteToServer(reader);
516+
}
517+
518+
using var cmd1 = new MySqlCommand("select * from bulk_load_data_reader_source order by value;", connection);
519+
using var cmd2 = new MySqlCommand("select * from bulk_load_data_reader_destination order by value;", connection2);
520+
using var reader1 = cmd1.ExecuteReader();
521+
using var reader2 = cmd2.ExecuteReader();
522+
while (reader1.Read())
523+
{
524+
Assert.True(reader2.Read());
525+
Assert.Equal(reader1.GetInt32(0), reader2.GetInt32(0));
526+
Assert.Equal(reader1.GetString(1), reader2.GetString(1));
527+
}
528+
Assert.False(reader2.Read());
529+
}
530+
495531
#if !NETCOREAPP1_1_2
496532
[Fact]
497533
public void BulkLoadDataTableWithLongData()

0 commit comments

Comments
 (0)