Skip to content

Commit 73e3110

Browse files
committed
Wrap IOException in MySqlException. Fixes #388
1 parent 635aa9b commit 73e3110

File tree

4 files changed

+77
-52
lines changed

4 files changed

+77
-52
lines changed

src/MySqlConnector/Core/ServerSession.cs

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -193,71 +193,78 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
193193

194194
public async Task ConnectAsync(ConnectionSettings cs, ILoadBalancer loadBalancer, IOBehavior ioBehavior, CancellationToken cancellationToken)
195195
{
196-
lock (m_lock)
197-
{
198-
VerifyState(State.Created);
199-
m_state = State.Connecting;
200-
}
201-
var connected = false;
202-
if (cs.ConnectionType == ConnectionType.Tcp)
203-
connected = await OpenTcpSocketAsync(cs, loadBalancer, ioBehavior, cancellationToken).ConfigureAwait(false);
204-
else if (cs.ConnectionType == ConnectionType.Unix)
205-
connected = await OpenUnixSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
206-
if (!connected)
196+
try
207197
{
208198
lock (m_lock)
209-
m_state = State.Failed;
210-
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");
211-
}
199+
{
200+
VerifyState(State.Created);
201+
m_state = State.Connecting;
202+
}
203+
var connected = false;
204+
if (cs.ConnectionType == ConnectionType.Tcp)
205+
connected = await OpenTcpSocketAsync(cs, loadBalancer, ioBehavior, cancellationToken).ConfigureAwait(false);
206+
else if (cs.ConnectionType == ConnectionType.Unix)
207+
connected = await OpenUnixSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
208+
if (!connected)
209+
{
210+
lock (m_lock)
211+
m_state = State.Failed;
212+
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");
213+
}
212214

213-
var byteHandler = new SocketByteHandler(m_socket);
214-
m_payloadHandler = new StandardPayloadHandler(byteHandler);
215+
var byteHandler = new SocketByteHandler(m_socket);
216+
m_payloadHandler = new StandardPayloadHandler(byteHandler);
215217

216-
var payload = await ReceiveAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
217-
var initialHandshake = InitialHandshakePayload.Create(payload);
218+
var payload = await ReceiveAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
219+
var initialHandshake = InitialHandshakePayload.Create(payload);
218220

219-
// if PluginAuth is supported, then use the specified auth plugin; else, fall back to protocol capabilities to determine the auth type to use
220-
string authPluginName;
221-
if ((initialHandshake.ProtocolCapabilities & ProtocolCapabilities.PluginAuth) != 0)
222-
authPluginName = initialHandshake.AuthPluginName;
223-
else
224-
authPluginName = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.SecureConnection) == 0 ? "mysql_old_password" : "mysql_native_password";
225-
if (authPluginName != "mysql_native_password" && authPluginName != "sha256_password" && authPluginName != "caching_sha2_password")
226-
throw new NotSupportedException("Authentication method '{0}' is not supported.".FormatInvariant(initialHandshake.AuthPluginName));
221+
// if PluginAuth is supported, then use the specified auth plugin; else, fall back to protocol capabilities to determine the auth type to use
222+
string authPluginName;
223+
if ((initialHandshake.ProtocolCapabilities & ProtocolCapabilities.PluginAuth) != 0)
224+
authPluginName = initialHandshake.AuthPluginName;
225+
else
226+
authPluginName = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.SecureConnection) == 0 ? "mysql_old_password" : "mysql_native_password";
227+
if (authPluginName != "mysql_native_password" && authPluginName != "sha256_password" && authPluginName != "caching_sha2_password")
228+
throw new NotSupportedException("Authentication method '{0}' is not supported.".FormatInvariant(initialHandshake.AuthPluginName));
227229

228-
ServerVersion = new ServerVersion(Encoding.ASCII.GetString(initialHandshake.ServerVersion));
229-
ConnectionId = initialHandshake.ConnectionId;
230-
AuthPluginData = initialHandshake.AuthPluginData;
231-
m_useCompression = cs.UseCompression && (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Compress) != 0;
230+
ServerVersion = new ServerVersion(Encoding.ASCII.GetString(initialHandshake.ServerVersion));
231+
ConnectionId = initialHandshake.ConnectionId;
232+
AuthPluginData = initialHandshake.AuthPluginData;
233+
m_useCompression = cs.UseCompression && (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Compress) != 0;
232234

233-
var serverSupportsSsl = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Ssl) != 0;
234-
if (cs.SslMode != MySqlSslMode.None && (cs.SslMode != MySqlSslMode.Preferred || serverSupportsSsl))
235-
{
236-
if (!serverSupportsSsl)
237-
throw new MySqlException("Server does not support SSL");
238-
await InitSslAsync(initialHandshake.ProtocolCapabilities, cs, ioBehavior, cancellationToken).ConfigureAwait(false);
239-
}
235+
var serverSupportsSsl = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Ssl) != 0;
236+
if (cs.SslMode != MySqlSslMode.None && (cs.SslMode != MySqlSslMode.Preferred || serverSupportsSsl))
237+
{
238+
if (!serverSupportsSsl)
239+
throw new MySqlException("Server does not support SSL");
240+
await InitSslAsync(initialHandshake.ProtocolCapabilities, cs, ioBehavior, cancellationToken).ConfigureAwait(false);
241+
}
240242

241-
m_supportsConnectionAttributes = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.ConnectionAttributes) != 0;
242-
if (m_supportsConnectionAttributes && s_connectionAttributes == null)
243-
s_connectionAttributes = CreateConnectionAttributes();
243+
m_supportsConnectionAttributes = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.ConnectionAttributes) != 0;
244+
if (m_supportsConnectionAttributes && s_connectionAttributes == null)
245+
s_connectionAttributes = CreateConnectionAttributes();
244246

245-
m_supportsDeprecateEof = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0;
247+
m_supportsDeprecateEof = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0;
246248

247-
payload = HandshakeResponse41Payload.Create(initialHandshake, cs, m_useCompression, m_supportsConnectionAttributes ? s_connectionAttributes : null);
248-
await SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
249-
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
249+
payload = HandshakeResponse41Payload.Create(initialHandshake, cs, m_useCompression, m_supportsConnectionAttributes ? s_connectionAttributes : null);
250+
await SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false);
251+
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
250252

251-
// if server doesn't support the authentication fast path, it will send a new challenge
252-
if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature)
253-
{
254-
payload = await SwitchAuthenticationAsync(cs, payload, ioBehavior, cancellationToken).ConfigureAwait(false);
255-
}
253+
// if server doesn't support the authentication fast path, it will send a new challenge
254+
if (payload.HeaderByte == AuthenticationMethodSwitchRequestPayload.Signature)
255+
{
256+
payload = await SwitchAuthenticationAsync(cs, payload, ioBehavior, cancellationToken).ConfigureAwait(false);
257+
}
256258

257-
OkPayload.Create(payload);
259+
OkPayload.Create(payload);
258260

259-
if (m_useCompression)
260-
m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler);
261+
if (m_useCompression)
262+
m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler);
263+
}
264+
catch (IOException ex)
265+
{
266+
throw new MySqlException("Couldn't connect to server", ex);
267+
}
261268
}
262269

263270
public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)

tests/MySqlConnector.Tests/ConnectionTests.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,16 @@ public void AuthPluginNameNotNullTerminated()
142142
}
143143
}
144144

145+
[Fact]
146+
public void IncompleteServerHandshake()
147+
{
148+
m_server.SendIncompletePostHandshakeResponse = true;
149+
using (var connection = new MySqlConnection(m_csb.ConnectionString))
150+
{
151+
Assert.Throws<MySqlException>(() => connection.Open());
152+
}
153+
}
154+
145155
private static async Task WaitForConditionAsync<T>(T expected, Func<T> getValue)
146156
{
147157
var sw = Stopwatch.StartNew();

tests/MySqlConnector.Tests/FakeMySqlServer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public void Stop()
4646
public string ServerVersion { get; set; } = "5.7.10-test";
4747

4848
public bool SuppressAuthPluginNameTerminatingNull { get; set; }
49+
public bool SendIncompletePostHandshakeResponse { get; set; }
4950

5051
internal void ClientDisconnected() => Interlocked.Decrement(ref m_activeConnections);
5152

tests/MySqlConnector.Tests/FakeMySqlServerConnection.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ public async Task RunAsync(TcpClient client, CancellationToken token)
2828
{
2929
await SendAsync(stream, 0, WriteInitialHandshake);
3030
await ReadPayloadAsync(stream, token); // handshake response
31+
32+
if (m_server.SendIncompletePostHandshakeResponse)
33+
{
34+
await stream.WriteAsync(new byte[] { 1, 0, 0, 2 }, 0, 4);
35+
return;
36+
}
37+
3138
await SendAsync(stream, 2, WriteOk);
3239

3340
var keepRunning = true;

0 commit comments

Comments
 (0)