Skip to content

Commit 3894332

Browse files
committed
Use FileStream in case of Local Infile.
1 parent a4f7f6c commit 3894332

File tree

6 files changed

+115
-17
lines changed

6 files changed

+115
-17
lines changed

src/MySqlConnector/MySqlClient/MySqlBulkLoader.cs

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using System.Collections.Generic;
33
using System.Data;
44
using System.IO;
@@ -153,15 +153,22 @@ private async Task<int> LoadAsync(IOBehavior ioBehavior, CancellationToken cance
153153
if (!string.IsNullOrWhiteSpace(FileName) && SourceStream != null)
154154
throw new InvalidOperationException("Cannot set both FileName and SourceStream");
155155

156-
if (string.IsNullOrWhiteSpace(FileName) && SourceStream != null)
157-
{
158-
if (!Local)
159-
throw new InvalidOperationException("Cannot use SourceStream when Local is not true.");
156+
// LOCAL INFILE case
157+
if (!string.IsNullOrWhiteSpace(FileName) && Local)
158+
{
159+
SourceStream = CreateFileStream(FileName);
160+
FileName = null;
161+
}
160162

161-
FileName = StreamPrefix + Guid.NewGuid().ToString("N");
162-
lock (s_lock)
163-
s_streams.Add(FileName, SourceStream);
164-
}
163+
if (string.IsNullOrWhiteSpace(FileName) && SourceStream != null)
164+
{
165+
if (!Local)
166+
throw new InvalidOperationException("Cannot use SourceStream when Local is not true.");
167+
168+
FileName = GenerateSourceStreamName();
169+
lock (s_lock)
170+
s_streams.Add(FileName, SourceStream);
171+
}
165172

166173
if (string.IsNullOrWhiteSpace(FileName) || string.IsNullOrWhiteSpace(TableName))
167174
{
@@ -179,7 +186,7 @@ private async Task<int> LoadAsync(IOBehavior ioBehavior, CancellationToken cance
179186
closeConnection = true;
180187
Connection.Open();
181188
}
182-
189+
183190
try
184191
{
185192
var commandString = BuildSqlCommand();
@@ -196,7 +203,24 @@ private async Task<int> LoadAsync(IOBehavior ioBehavior, CancellationToken cance
196203
}
197204
}
198205

199-
internal const string StreamPrefix = ":STREAM:";
206+
private Stream CreateFileStream(string fileName)
207+
{
208+
try
209+
{
210+
return File.OpenRead(fileName);
211+
}
212+
catch (Exception ex)
213+
{
214+
throw new MySqlException($"Could not access file \"{fileName}\"", ex);
215+
}
216+
}
217+
218+
private static string GenerateSourceStreamName()
219+
{
220+
return StreamPrefix + Guid.NewGuid().ToString("N");
221+
}
222+
223+
internal const string StreamPrefix = ":STREAM:";
200224

201225
internal static Stream GetAndRemoveStream(string streamKey)
202226
{

src/MySqlConnector/MySqlClient/MySqlConnection.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ internal async Task<CachedProcedure> GetCachedProcedure(IOBehavior ioBehavior, s
301301
internal bool TreatTinyAsBoolean => m_connectionSettings.TreatTinyAsBoolean;
302302
internal IOBehavior AsyncIOBehavior => m_connectionSettings.ForceSynchronous ? IOBehavior.Synchronous : IOBehavior.Asynchronous;
303303

304+
internal MySqlSslMode SslMode => m_connectionSettings.SslMode;
305+
304306
internal void SetActiveReader(MySqlDataReader dataReader)
305307
{
306308
if (dataReader == null)

src/MySqlConnector/MySqlClient/Results/ResultSet.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ public async Task<ResultSet> ReadResultSetHeaderAsync(IOBehavior ioBehavior)
5656
try
5757
{
5858
var localInfile = LocalInfilePayload.Create(payload);
59+
if(!IsHostVerified(Connection)
60+
&& !localInfile.FileName.StartsWith(MySqlBulkLoader.StreamPrefix, StringComparison.Ordinal))
61+
throw new NotSupportedException("Use SourceStream or SslMode >= VerifyCA for LOAD DATA LOCAL INFILE");
62+
5963
using (var stream = localInfile.FileName.StartsWith(MySqlBulkLoader.StreamPrefix, StringComparison.Ordinal) ?
6064
MySqlBulkLoader.GetAndRemoveStream(localInfile.FileName) :
6165
File.OpenRead(localInfile.FileName))
@@ -129,6 +133,12 @@ public async Task<ResultSet> ReadResultSetHeaderAsync(IOBehavior ioBehavior)
129133
return this;
130134
}
131135

136+
private bool IsHostVerified(MySqlConnection connection)
137+
{
138+
return connection.SslMode == MySqlSslMode.VerifyCA
139+
|| connection.SslMode == MySqlSslMode.VerifyFull;
140+
}
141+
132142
public async Task BufferEntireAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
133143
{
134144
while (BufferState == ResultSetState.ReadingRows || BufferState == ResultSetState.ReadResultSetHeader)

tests/SideBySide/Attributes.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ public BulkLoaderLocalCsvFileFactAttribute()
9898
if(string.IsNullOrWhiteSpace(AppConfig.MySqlBulkLoaderLocalCsvFile))
9999
Skip = "No bulk loader local CSV file specified";
100100
}
101+
102+
public bool TrustedHost
103+
{
104+
105+
get => _trustedHost;
106+
set
107+
{
108+
_trustedHost = value;
109+
110+
var csb = AppConfig.CreateConnectionStringBuilder();
111+
if (_trustedHost)
112+
{
113+
if (csb.SslMode == MySqlSslMode.None
114+
|| csb.SslMode == MySqlSslMode.Preferred
115+
|| csb.SslMode == MySqlSslMode.Required)
116+
Skip = "SslMode should be VerifyCA or higher.";
117+
}
118+
else
119+
{
120+
if (csb.SslMode == MySqlSslMode.VerifyCA
121+
|| csb.SslMode == MySqlSslMode.VerifyFull)
122+
Skip = "SslMode should be less than VerifyCA.";
123+
}
124+
}
125+
}
126+
private bool _trustedHost;
101127
}
102128

103129
public class BulkLoaderLocalTsvFileFactAttribute : FactAttribute

tests/SideBySide/LoadDataInfileAsync.cs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Data;
1+
using System.Data;
22
using System.Threading.Tasks;
33
using MySql.Data.MySqlClient;
44
using Xunit;
@@ -42,17 +42,37 @@ public async void CommandLoadCsvFile()
4242
Assert.Equal(20, rowCount);
4343
}
4444

45-
[BulkLoaderLocalCsvFileFact]
45+
[BulkLoaderLocalCsvFileFact(TrustedHost = true)]
4646
public async void CommandLoadLocalCsvFile()
4747
{
48-
string insertInlineCommand = string.Format(m_loadDataInfileCommand, " LOCAL", AppConfig.MySqlBulkLoaderLocalCsvFile.Replace("\\", "\\\\"));
48+
string insertInlineCommand = string.Format(m_loadDataInfileCommand, " LOCAL",
49+
AppConfig.MySqlBulkLoaderLocalCsvFile.Replace("\\", "\\\\"));
4950
MySqlCommand command = new MySqlCommand(insertInlineCommand, m_database.Connection);
50-
if (m_database.Connection.State != ConnectionState.Open) await m_database.Connection.OpenAsync();
51+
52+
if (m_database.Connection.State != ConnectionState.Open)
53+
await m_database.Connection.OpenAsync();
54+
5155
int rowCount = await command.ExecuteNonQueryAsync();
56+
5257
m_database.Connection.Close();
5358
Assert.Equal(20, rowCount);
5459
}
5560

61+
#if !BASELINE
62+
[BulkLoaderLocalCsvFileFact(TrustedHost = false)]
63+
public async void ThrowsNotSupportedExceptionForNotTrustedHostAndNotStream()
64+
{
65+
string insertInlineCommand = string.Format(m_loadDataInfileCommand, " LOCAL",
66+
AppConfig.MySqlBulkLoaderLocalCsvFile.Replace("\\", "\\\\"));
67+
MySqlCommand command = new MySqlCommand(insertInlineCommand, m_database.Connection);
68+
69+
if (m_database.Connection.State != ConnectionState.Open)
70+
await m_database.Connection.OpenAsync();
71+
72+
await Assert.ThrowsAsync<MySqlException>(async () => await command.ExecuteNonQueryAsync());
73+
}
74+
#endif
75+
5676
readonly DatabaseFixture m_database;
5777
readonly string m_testTable;
5878
readonly string m_loadDataInfileCommand;

tests/SideBySide/LoadDataInfileSync.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Data;
1+
using System.Data;
22
using MySql.Data.MySqlClient;
33
using Xunit;
44
using Dapper;
@@ -41,7 +41,7 @@ public void CommandLoadCsvFile()
4141
Assert.Equal(20, rowCount);
4242
}
4343

44-
[BulkLoaderLocalCsvFileFact]
44+
[BulkLoaderLocalCsvFileFact(TrustedHost = true)]
4545
public void CommandLoadLocalCsvFile()
4646
{
4747
string insertInlineCommand = string.Format(m_loadDataInfileCommand, " LOCAL", AppConfig.MySqlBulkLoaderLocalCsvFile.Replace("\\", "\\\\"));
@@ -52,6 +52,22 @@ public void CommandLoadLocalCsvFile()
5252
Assert.Equal(20, rowCount);
5353
}
5454

55+
#if !BASELINE
56+
[BulkLoaderLocalCsvFileFact(TrustedHost = false)]
57+
public void ThrowsNotSupportedExceptionForNotTrustedHostAndNotStream()
58+
{
59+
string insertInlineCommand = string.Format(m_loadDataInfileCommand, " LOCAL",
60+
AppConfig.MySqlBulkLoaderLocalCsvFile.Replace("\\", "\\\\"));
61+
MySqlCommand command = new MySqlCommand(insertInlineCommand, m_database.Connection);
62+
if (m_database.Connection.State != ConnectionState.Open)
63+
m_database.Connection.Open();
64+
65+
Assert.Throws<MySqlException>(() => command.ExecuteNonQuery());
66+
67+
m_database.Connection.Close();
68+
}
69+
#endif
70+
5571
readonly DatabaseFixture m_database;
5672
readonly string m_testTable;
5773
readonly string m_loadDataInfileCommand;

0 commit comments

Comments
 (0)