diff --git a/.ci/config/config.compression+ssl.json b/.ci/config/config.compression+ssl.json index 8a11c95bf..71cd817a0 100644 --- a/.ci/config/config.compression+ssl.json +++ b/.ci/config/config.compression+ssl.json @@ -4,7 +4,7 @@ "SocketPath": "./../../../../.ci/run/mysql/mysqld.sock", "PasswordlessUser": "no_password", "SecondaryDatabase": "testdb2", - "UnsupportedFeatures": "RsaEncryption,CachingSha2Password,Tls12,Tls13,UuidToBin", + "UnsupportedFeatures": "CachingSha2Password,Redirection,RsaEncryption,Tls12,Tls13,UuidToBin", "MySqlBulkLoaderLocalCsvFile": "../../../TestData/LoadData_UTF8_BOM_Unix.CSV", "MySqlBulkLoaderLocalTsvFile": "../../../TestData/LoadData_UTF8_BOM_Unix.TSV", "CertificatesPath": "../../../../.ci/server/certs" diff --git a/.ci/config/config.compression.json b/.ci/config/config.compression.json index f42f53ab6..bf1073a12 100644 --- a/.ci/config/config.compression.json +++ b/.ci/config/config.compression.json @@ -4,7 +4,7 @@ "SocketPath": "./../../../../.ci/run/mysql/mysqld.sock", "PasswordlessUser": "no_password", "SecondaryDatabase": "testdb2", - "UnsupportedFeatures": "Ed25519,QueryAttributes,StreamingResults,Tls11,UnixDomainSocket,ZeroDateTime", + "UnsupportedFeatures": "Ed25519,QueryAttributes,Redirection,StreamingResults,Tls11,UnixDomainSocket,ZeroDateTime", "MySqlBulkLoaderLocalCsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.CSV", "MySqlBulkLoaderLocalTsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.TSV" } diff --git a/.ci/config/config.json b/.ci/config/config.json index 183b2299c..035c05855 100644 --- a/.ci/config/config.json +++ b/.ci/config/config.json @@ -4,7 +4,7 @@ "SocketPath": "./../../../../.ci/run/mysql/mysqld.sock", "PasswordlessUser": "no_password", "SecondaryDatabase": "testdb2", - "UnsupportedFeatures": "Ed25519,QueryAttributes,StreamingResults,Tls11,UnixDomainSocket,ZeroDateTime", + "UnsupportedFeatures": "Ed25519,QueryAttributes,Redirection,StreamingResults,Tls11,UnixDomainSocket,ZeroDateTime", "MySqlBulkLoaderLocalCsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.CSV", "MySqlBulkLoaderLocalTsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.TSV" } diff --git a/.ci/config/config.ssl.json b/.ci/config/config.ssl.json index 84261b1be..705e0a168 100644 --- a/.ci/config/config.ssl.json +++ b/.ci/config/config.ssl.json @@ -4,7 +4,7 @@ "SocketPath": "./../../../../.ci/run/mysql/mysqld.sock", "PasswordlessUser": "no_password", "SecondaryDatabase": "testdb2", - "UnsupportedFeatures": "RsaEncryption,CachingSha2Password,Tls12,Tls13,UuidToBin", + "UnsupportedFeatures": "CachingSha2Password,Redirection,RsaEncryption,Tls12,Tls13,UuidToBin", "MySqlBulkLoaderLocalCsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.CSV", "MySqlBulkLoaderLocalTsvFile": "../../../../tests/TestData/LoadData_UTF8_BOM_Unix.TSV", "CertificatesPath": "../../../../.ci/server/certs" diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e14ef7d82..47aa77a39 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -136,7 +136,7 @@ jobs: arguments: '-c Release --no-restore' testRunTitle: ${{ format('{0}, $(Agent.OS), {1}, {2}', 'mysql:8.0', 'net472/net8.0', 'No SSL') }} env: - DATA__UNSUPPORTEDFEATURES: 'Ed25519,QueryAttributes,StreamingResults,Tls11,UnixDomainSocket' + DATA__UNSUPPORTEDFEATURES: 'Ed25519,QueryAttributes,Redirection,StreamingResults,Tls11,UnixDomainSocket' DATA__CONNECTIONSTRING: 'server=localhost;port=3306;user id=mysqltest;password=test;database=mysqltest;ssl mode=none;DefaultCommandTimeout=3600;AllowPublicKeyRetrieval=True;UseCompression=True' - job: windows_integration_tests_2 @@ -174,7 +174,7 @@ jobs: arguments: '-c Release --no-restore' testRunTitle: ${{ format('{0}, $(Agent.OS), {1}, {2}', 'mysql:8.0', 'net6.0', 'No SSL') }} env: - DATA__UNSUPPORTEDFEATURES: 'Ed25519,QueryAttributes,StreamingResults,Tls11,UnixDomainSocket' + DATA__UNSUPPORTEDFEATURES: 'Ed25519,QueryAttributes,Redirection,StreamingResults,Tls11,UnixDomainSocket' DATA__CONNECTIONSTRING: 'server=localhost;port=3306;user id=mysqltest;password=test;database=mysqltest;ssl mode=none;DefaultCommandTimeout=3600;AllowPublicKeyRetrieval=True' - job: linux_integration_tests @@ -187,27 +187,27 @@ jobs: 'MySQL 8.0': image: 'mysql:8.0' connectionStringExtra: 'AllowPublicKeyRetrieval=True' - unsupportedFeatures: 'Ed25519,StreamingResults,Tls11,ZeroDateTime' + unsupportedFeatures: 'Ed25519,Redirection,StreamingResults,Tls11,ZeroDateTime' 'MySQL 8.4': image: 'mysql:8.4' connectionStringExtra: 'AllowPublicKeyRetrieval=True' - unsupportedFeatures: 'Ed25519,StreamingResults,Tls11,ZeroDateTime' + unsupportedFeatures: 'Ed25519,Redirection,StreamingResults,Tls11,ZeroDateTime' 'MySQL 9.0': image: 'mysql:9.0' connectionStringExtra: 'AllowPublicKeyRetrieval=True' - unsupportedFeatures: 'Ed25519,StreamingResults,Tls11,ZeroDateTime' + unsupportedFeatures: 'Ed25519,Redirection,StreamingResults,Tls11,ZeroDateTime' 'MariaDB 10.6': image: 'mariadb:10.6' connectionStringExtra: '' - unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin' + unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin,Redirection' 'MariaDB 10.11': image: 'mariadb:10.11' connectionStringExtra: '' - unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin' + unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin,Redirection' 'MariaDB 11.4': image: 'mariadb:11.4' connectionStringExtra: '' - unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin' + unsupportedFeatures: 'CachingSha2Password,CancelSleepSuccessfully,Json,RoundDateTime,QueryAttributes,Sha256Password,Tls11,UuidToBin,Redirection' steps: - template: '.ci/integration-tests-steps.yml' parameters: diff --git a/src/MySqlConnector/Core/ConnectionPool.cs b/src/MySqlConnector/Core/ConnectionPool.cs index 17227b631..a11391cd5 100644 --- a/src/MySqlConnector/Core/ConnectionPool.cs +++ b/src/MySqlConnector/Core/ConnectionPool.cs @@ -8,10 +8,16 @@ namespace MySqlConnector.Core; -internal sealed class ConnectionPool : IDisposable +internal sealed class ConnectionPool : IConnectionPoolMetadata, IDisposable { public int Id { get; } + ConnectionPool? IConnectionPoolMetadata.ConnectionPool => this; + + int IConnectionPoolMetadata.Generation => m_generation; + + int IConnectionPoolMetadata.GetNewSessionId() => Interlocked.Increment(ref m_lastSessionId); + public string? Name { get; } public ConnectionSettings ConnectionSettings { get; } @@ -95,6 +101,7 @@ public async ValueTask GetSessionAsync(MySqlConnection connection m_leasedSessions.Add(session.Id, session); leasedSessionsCountPooled = m_leasedSessions.Count; } + MetricsReporter.AddUsed(this); ActivitySourceHelper.CopyTags(session.ActivityTags, activity); Log.ReturningPooledSession(m_logger, Id, session.Id, leasedSessionsCountPooled); @@ -106,7 +113,8 @@ public async ValueTask GetSessionAsync(MySqlConnection connection } // create a new session - session = await ConnectSessionAsync(connection, s_createdNewSession, startingTimestamp, activity, ioBehavior, cancellationToken).ConfigureAwait(false); + session = await ServerSession.ConnectAndRedirectAsync(m_connectionLogger, m_logger, this, ConnectionSettings, m_loadBalancer, + connection, s_createdNewSession, startingTimestamp, activity, ioBehavior, cancellationToken).ConfigureAwait(false); AdjustHostConnectionCount(session, 1); session.OwningConnection = new(connection); int leasedSessionsCountNew; @@ -402,7 +410,8 @@ private async Task CreateMinimumPooledSessions(MySqlConnection connection, IOBeh try { - var session = await ConnectSessionAsync(connection, s_createdToReachMinimumPoolSize, Stopwatch.GetTimestamp(), null, ioBehavior, cancellationToken).ConfigureAwait(false); + var session = await ServerSession.ConnectAndRedirectAsync(m_connectionLogger, m_logger, this, ConnectionSettings, m_loadBalancer, + connection, s_createdToReachMinimumPoolSize, Stopwatch.GetTimestamp(), null, ioBehavior, cancellationToken).ConfigureAwait(false); AdjustHostConnectionCount(session, 1); lock (m_sessions) _ = m_sessions.AddFirst(session); @@ -416,81 +425,6 @@ private async Task CreateMinimumPooledSessions(MySqlConnection connection, IOBeh } } - private async ValueTask ConnectSessionAsync(MySqlConnection connection, Action logMessage, long startingTimestamp, Activity? activity, IOBehavior ioBehavior, CancellationToken cancellationToken) - { - var session = new ServerSession(m_connectionLogger, this, m_generation, Interlocked.Increment(ref m_lastSessionId)); - if (m_logger.IsEnabled(LogLevel.Debug)) - logMessage(m_logger, Id, session.Id, null); - string? statusInfo; - try - { - statusInfo = await session.ConnectAsync(ConnectionSettings, connection, startingTimestamp, m_loadBalancer, activity, ioBehavior, cancellationToken).ConfigureAwait(false); - } - catch (Exception) - { - await session.DisposeAsync(ioBehavior, default).ConfigureAwait(false); - throw; - } - - Exception? redirectionException = null; - if (statusInfo is not null && statusInfo.StartsWith("Location: mysql://", StringComparison.Ordinal)) - { - // server redirection string has the format "Location: mysql://{host}:{port}/user={userId}[&ttl={ttl}]" - Log.HasServerRedirectionHeader(m_logger, session.Id, statusInfo); - - if (ConnectionSettings.ServerRedirectionMode == MySqlServerRedirectionMode.Disabled) - { - Log.ServerRedirectionIsDisabled(m_logger, Id); - } - else if (Utility.TryParseRedirectionHeader(statusInfo, out var host, out var port, out var user)) - { - if (host != ConnectionSettings.HostNames![0] || port != ConnectionSettings.Port || user != ConnectionSettings.UserID) - { - var redirectedSettings = ConnectionSettings.CloneWith(host, port, user); - Log.OpeningNewConnection(m_logger, Id, host, port, user); - var redirectedSession = new ServerSession(m_connectionLogger, this, m_generation, Interlocked.Increment(ref m_lastSessionId)); - try - { - _ = await redirectedSession.ConnectAsync(redirectedSettings, connection, startingTimestamp, m_loadBalancer, activity, ioBehavior, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - Log.FailedToConnectRedirectedSession(m_logger, ex, Id, redirectedSession.Id); - redirectionException = ex; - } - - if (redirectionException is null) - { - Log.ClosingSessionToUseRedirectedSession(m_logger, Id, session.Id, redirectedSession.Id); - await session.DisposeAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - return redirectedSession; - } - else - { - try - { - await redirectedSession.DisposeAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - } - catch (Exception) - { - } - } - } - else - { - Log.SessionAlreadyConnectedToServer(m_logger, session.Id); - } - } - } - - if (ConnectionSettings.ServerRedirectionMode == MySqlServerRedirectionMode.Required) - { - Log.RequiresServerRedirection(m_logger, Id); - throw new MySqlException(MySqlErrorCode.UnableToConnectToHost, "Server does not support redirection", redirectionException); - } - return session; - } - public static ConnectionPool? CreatePool(string connectionString, MySqlConnectorLoggingConfiguration loggingConfiguration, string? name) { // parse connection string and check for 'Pooling' setting; return 'null' if pooling is disabled diff --git a/src/MySqlConnector/Core/ConnectionSettings.cs b/src/MySqlConnector/Core/ConnectionSettings.cs index c93e16004..8b365833c 100644 --- a/src/MySqlConnector/Core/ConnectionSettings.cs +++ b/src/MySqlConnector/Core/ConnectionSettings.cs @@ -270,8 +270,12 @@ public int ConnectionTimeoutMilliseconds private ConnectionSettings(ConnectionSettings other, string host, int port, string userId) { - ConnectionStringBuilder = other.ConnectionStringBuilder; - ConnectionString = other.ConnectionString; + ConnectionStringBuilder = new MySqlConnectionStringBuilder(other.ConnectionString); + ConnectionStringBuilder.Port = (uint)port; + ConnectionStringBuilder.Server = host; + ConnectionStringBuilder.UserID = userId; + + ConnectionString = ConnectionStringBuilder.ConnectionString; ConnectionProtocol = MySqlConnectionProtocol.Sockets; HostNames = [host]; diff --git a/src/MySqlConnector/Core/IConnectionPoolMetadata.cs b/src/MySqlConnector/Core/IConnectionPoolMetadata.cs new file mode 100644 index 000000000..64f8dd31b --- /dev/null +++ b/src/MySqlConnector/Core/IConnectionPoolMetadata.cs @@ -0,0 +1,26 @@ +namespace MySqlConnector.Core; + +internal interface IConnectionPoolMetadata +{ + /// + /// Returns the this is associated with, + /// or null if it represents a non-pooled connection. + /// + ConnectionPool? ConnectionPool { get; } + + /// + /// Returns the ID of the connection pool, or 0 if this is a non-pooled connection. + /// + int Id { get; } + + /// + /// Returns the generation of the connection pool, or 0 if this is a non-pooled connection. + /// + int Generation { get; } + + /// + /// Returns a new session ID. + /// + /// A new session ID. + int GetNewSessionId(); +} diff --git a/src/MySqlConnector/Core/NonPooledConnectionPoolMetadata.cs b/src/MySqlConnector/Core/NonPooledConnectionPoolMetadata.cs new file mode 100644 index 000000000..20d296722 --- /dev/null +++ b/src/MySqlConnector/Core/NonPooledConnectionPoolMetadata.cs @@ -0,0 +1,13 @@ +namespace MySqlConnector.Core; + +internal sealed class NonPooledConnectionPoolMetadata : IConnectionPoolMetadata +{ + public static IConnectionPoolMetadata Instance { get; } = new NonPooledConnectionPoolMetadata(); + + public ConnectionPool? ConnectionPool => null; + public int Id => 0; + public int Generation => 0; + public int GetNewSessionId() => Interlocked.Increment(ref m_lastId); + + private int m_lastId; +} diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index dcaf1e69f..dcd64d009 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -25,21 +25,16 @@ namespace MySqlConnector.Core; internal sealed partial class ServerSession : IServerCapabilities { - public ServerSession(ILogger logger) - : this(logger, null, 0, Interlocked.Increment(ref s_lastId)) - { - } - - public ServerSession(ILogger logger, ConnectionPool? pool, int poolGeneration, int id) + public ServerSession(ILogger logger, IConnectionPoolMetadata pool) { m_logger = logger; m_lock = new(); m_payloadCache = new(); - Id = (pool?.Id ?? 0) + "." + id; + Id = pool.Id + "." + pool.GetNewSessionId(); ServerVersion = ServerVersion.Empty; CreatedTimestamp = Stopwatch.GetTimestamp(); - Pool = pool; - PoolGeneration = poolGeneration; + Pool = pool.ConnectionPool; + PoolGeneration = pool.Generation; HostName = ""; m_activityTags = []; DataReader = new(); @@ -391,7 +386,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella m_state = State.Closed; } - public async Task ConnectAsync(ConnectionSettings cs, MySqlConnection connection, long startingTimestamp, ILoadBalancer? loadBalancer, Activity? activity, IOBehavior ioBehavior, CancellationToken cancellationToken) + private async Task ConnectAsync(ConnectionSettings cs, MySqlConnection connection, long startingTimestamp, ILoadBalancer? loadBalancer, Activity? activity, IOBehavior ioBehavior, CancellationToken cancellationToken) { try { @@ -533,7 +528,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } var ok = OkPayload.Create(payload.Span, this); - var statusInfo = ok.StatusInfo; + var redirectionUrl = ok.RedirectionUrl; if (m_useCompression) m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler); @@ -558,7 +553,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout; - return statusInfo; + return redirectionUrl; } catch (ArgumentException ex) { @@ -572,6 +567,75 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } } + public static async ValueTask ConnectAndRedirectAsync(ILogger connectionLogger, ILogger poolLogger, IConnectionPoolMetadata pool, ConnectionSettings cs, ILoadBalancer? loadBalancer, MySqlConnection connection, Action? logMessage, long startingTimestamp, Activity? activity, IOBehavior ioBehavior, CancellationToken cancellationToken) + { + var session = new ServerSession(connectionLogger, pool); + if (logMessage is not null && poolLogger.IsEnabled(LogLevel.Debug)) + logMessage(poolLogger, pool.Id, session.Id, null); + + string? redirectionUrl; + try + { + redirectionUrl = await session.ConnectAsync(cs, connection, startingTimestamp, loadBalancer, activity, ioBehavior, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + await session.DisposeAsync(ioBehavior, default).ConfigureAwait(false); + throw; + } + + Exception? redirectionException = null; + if (redirectionUrl is not null) + { + Log.HasServerRedirectionHeader(connectionLogger, session.Id, redirectionUrl); + if (cs.ServerRedirectionMode == MySqlServerRedirectionMode.Disabled) + { + Log.ServerRedirectionIsDisabled(connectionLogger, session.Id); + return session; + } + + if (Utility.TryParseRedirectionHeader(redirectionUrl, cs.UserID, out var host, out var port, out var user)) + { + if (host != cs.HostNames![0] || port != cs.Port || user != cs.UserID) + { + var redirectedSettings = cs.CloneWith(host, port, user); + Log.OpeningNewConnection(connectionLogger, session.Id, host, port, user); + var redirectedSession = new ServerSession(connectionLogger, pool); + try + { + await redirectedSession.ConnectAsync(redirectedSettings, connection, startingTimestamp, loadBalancer, activity, ioBehavior, cancellationToken).ConfigureAwait(false); + Log.ClosingSessionToUseRedirectedSession(connectionLogger, session.Id, redirectedSession.Id); + await session.DisposeAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + return redirectedSession; + } + catch (Exception ex) + { + redirectionException = ex; + Log.FailedToConnectRedirectedSession(connectionLogger, ex, session.Id, redirectedSession.Id); + try + { + await redirectedSession.DisposeAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + } + } + } + else + { + Log.SessionAlreadyConnectedToServer(connectionLogger, session.Id); + } + } + } + + if (cs.ServerRedirectionMode == MySqlServerRedirectionMode.Required) + { + Log.RequiresServerRedirection(connectionLogger, session.Id); + throw new MySqlException(MySqlErrorCode.UnableToConnectToHost, "Server does not support redirection", redirectionException); + } + return session; + } + public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConnection connection, IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyState(State.Connected); @@ -1922,7 +1986,6 @@ protected override void OnStatementBegin(int index) private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @\uE001MySqlConnector\uE001Sleep;"u8); private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8); private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8); - private static int s_lastId; private readonly ILogger m_logger; private readonly object m_lock; diff --git a/src/MySqlConnector/Logging/Log.cs b/src/MySqlConnector/Logging/Log.cs index 7cb3bd432..57b10a088 100644 --- a/src/MySqlConnector/Logging/Log.cs +++ b/src/MySqlConnector/Logging/Log.cs @@ -405,23 +405,23 @@ internal static partial class Log [LoggerMessage(EventIds.HasServerRedirectionHeader, LogLevel.Trace, "Session {SessionId} has server redirection header {Header}")] public static partial void HasServerRedirectionHeader(ILogger logger, string sessionId, string header); - [LoggerMessage(EventIds.ServerRedirectionIsDisabled, LogLevel.Trace, "Pool {PoolId} server redirection is disabled; ignoring redirection")] - public static partial void ServerRedirectionIsDisabled(ILogger logger, int poolId); + [LoggerMessage(EventIds.ServerRedirectionIsDisabled, LogLevel.Trace, "Session {SessionId} server redirection is disabled; ignoring redirection")] + public static partial void ServerRedirectionIsDisabled(ILogger logger, string sessionId); - [LoggerMessage(EventIds.OpeningNewConnection, LogLevel.Debug, "Pool {PoolId} opening new connection to {Host}:{Port} as {User}")] - public static partial void OpeningNewConnection(ILogger logger, int poolId, string host, int port, string user); + [LoggerMessage(EventIds.OpeningNewConnection, LogLevel.Debug, "Session {SessionId} opening new connection to {Host}:{Port} as {User}")] + public static partial void OpeningNewConnection(ILogger logger, string sessionId, string host, int port, string user); - [LoggerMessage(EventIds.FailedToConnectRedirectedSession, LogLevel.Information, "Pool {PoolId} failed to connect redirected session {SessionId}")] - public static partial void FailedToConnectRedirectedSession(ILogger logger, Exception ex, int poolId, string sessionId); + [LoggerMessage(EventIds.FailedToConnectRedirectedSession, LogLevel.Information, "Session {SessionId} failed to connect redirected session {RedirectedSessionId}")] + public static partial void FailedToConnectRedirectedSession(ILogger logger, Exception ex, string sessionId, string redirectedSessionId); - [LoggerMessage(EventIds.ClosingSessionToUseRedirectedSession, LogLevel.Trace, "Pool {PoolId} closing session {SessionId} to use redirected session {RedirectedSessionId} instead")] - public static partial void ClosingSessionToUseRedirectedSession(ILogger logger, int poolId, string sessionId, string redirectedSessionId); + [LoggerMessage(EventIds.ClosingSessionToUseRedirectedSession, LogLevel.Trace, "Closing session {SessionId} to use redirected session {RedirectedSessionId} instead")] + public static partial void ClosingSessionToUseRedirectedSession(ILogger logger, string sessionId, string redirectedSessionId); [LoggerMessage(EventIds.SessionAlreadyConnectedToServer, LogLevel.Trace, "Session {SessionId} is already connected to this server; ignoring redirection")] public static partial void SessionAlreadyConnectedToServer(ILogger logger, string sessionId); - [LoggerMessage(EventIds.RequiresServerRedirection, LogLevel.Error, "Pool {PoolId} requires server redirection but server doesn't support it")] - public static partial void RequiresServerRedirection(ILogger logger, int poolId); + [LoggerMessage(EventIds.RequiresServerRedirection, LogLevel.Error, "Session {SessionId} requires server redirection but server doesn't support it")] + public static partial void RequiresServerRedirection(ILogger logger, string sessionId); [LoggerMessage(EventIds.CreatedPoolWillNotBeUsed, LogLevel.Debug, "Pool {PoolId} was created but will not be used (due to race)")] public static partial void CreatedPoolWillNotBeUsed(ILogger logger, int poolId); diff --git a/src/MySqlConnector/MySqlConnection.cs b/src/MySqlConnector/MySqlConnection.cs index 7a835d202..571af0e51 100644 --- a/src/MySqlConnector/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlConnection.cs @@ -3,6 +3,7 @@ #if NET6_0_OR_GREATER using System.Globalization; #endif +using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Security.Authentication; @@ -1062,22 +1063,10 @@ private async ValueTask CreateSessionAsync(ConnectionPool? pool, // only "fail over" and "random" load balancers supported without connection pooling var loadBalancer = connectionSettings.LoadBalance == MySqlLoadBalance.Random && connectionSettings.HostNames!.Count > 1 ? RandomLoadBalancer.Instance : FailOverLoadBalancer.Instance; - - var session = new ServerSession(m_logger) - { - OwningConnection = new WeakReference(this), - }; + var session = await ServerSession.ConnectAndRedirectAsync(m_logger, m_logger, NonPooledConnectionPoolMetadata.Instance, connectionSettings, loadBalancer, this, null, startingTimestamp, null, actualIOBehavior, connectToken).ConfigureAwait(false); + session.OwningConnection = new WeakReference(this); Log.CreatedNonPooledSession(m_logger, session.Id); - try - { - _ = await session.ConnectAsync(connectionSettings, this, startingTimestamp, loadBalancer, activity, actualIOBehavior, connectToken).ConfigureAwait(false); - return session; - } - catch (Exception) - { - await session.DisposeAsync(actualIOBehavior, default).ConfigureAwait(false); - throw; - } + return session; } } catch (OperationCanceledException) when (timeoutSource?.IsCancellationRequested is true) @@ -1114,6 +1103,8 @@ private async ValueTask CreateSessionAsync(ConnectionPool? pool, internal SslProtocols SslProtocol => m_session!.SslProtocol; + internal IPEndPoint? SessionEndPoint => m_session!.IPEndPoint; + internal void SetState(ConnectionState newState) { if (m_connectionState != newState) diff --git a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs index a08a3195d..37db8862f 100644 --- a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs +++ b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs @@ -16,6 +16,7 @@ internal sealed class OkPayload public string? NewSchema { get; } public CharacterSet? NewCharacterSet { get; } public int? NewConnectionId { get; } + public string? RedirectionUrl { get; } public const byte Signature = 0x00; @@ -64,6 +65,7 @@ public static void Verify(ReadOnlySpan span, IServerCapabilities serverCap CharacterSet clientCharacterSet = default; CharacterSet connectionCharacterSet = default; CharacterSet resultsCharacterSet = default; + string? redirectionUrl = default; int? connectionId = null; ReadOnlySpan statusBytes; @@ -115,6 +117,13 @@ public static void Verify(ReadOnlySpan span, IServerCapabilities serverCap { connectionId = Utf8Parser.TryParse(systemVariableValue, out int parsedConnectionId, out var bytesConsumed) && bytesConsumed == systemVariableValue.Length ? parsedConnectionId : default(int?); } + else if (systemVariableName.SequenceEqual("redirect_url"u8)) + { + if (systemVariableValue.Length > 0) + { + redirectionUrl = Encoding.UTF8.GetString(systemVariableValue); + } + } } while (reader.Offset < systemVariablesEndOffset); break; @@ -150,7 +159,7 @@ public static void Verify(ReadOnlySpan span, IServerCapabilities serverCap clientCharacterSet == CharacterSet.Utf8Mb3Binary && connectionCharacterSet == CharacterSet.Utf8Mb3Binary && resultsCharacterSet == CharacterSet.Utf8Mb3Binary ? CharacterSet.Utf8Mb3Binary : CharacterSet.None; - if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is CharacterSet.None && connectionId is null) + if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null && clientCharacterSet is CharacterSet.None && connectionId is null && redirectionUrl is null) { if (serverStatus == ServerStatus.AutoCommit) return s_autoCommitOk; @@ -158,7 +167,7 @@ public static void Verify(ReadOnlySpan span, IServerCapabilities serverCap return s_autoCommitSessionStateChangedOk; } - return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, characterSet, connectionId); + return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema, characterSet, connectionId, redirectionUrl); } else { @@ -166,7 +175,7 @@ public static void Verify(ReadOnlySpan span, IServerCapabilities serverCap } } - private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, CharacterSet newCharacterSet, int? connectionId) + private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema, CharacterSet newCharacterSet, int? connectionId, string? redirectionUrl) { AffectedRowCount = affectedRowCount; LastInsertId = lastInsertId; @@ -176,8 +185,9 @@ private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serve NewSchema = newSchema; NewCharacterSet = newCharacterSet; NewConnectionId = connectionId; + RedirectionUrl = redirectionUrl; } - private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default); - private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, default, default, default, default); + private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, default, default, default, default, default); + private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, default, default, default, default, default); } diff --git a/src/MySqlConnector/Utilities/Utility.cs b/src/MySqlConnector/Utilities/Utility.cs index d1d3ce02b..63f07c8c6 100644 --- a/src/MySqlConnector/Utilities/Utility.cs +++ b/src/MySqlConnector/Utilities/Utility.cs @@ -336,68 +336,45 @@ public static void Resize([NotNull] ref ResizableArray? resizableArray, in resizableArray.DoResize(newLength); } - public static bool TryParseRedirectionHeader(string header, out string host, out int port, out string user) + public static bool TryParseRedirectionHeader(string redirectUrl, string initialUser, out string host, out int port, out string user) { host = ""; port = 0; user = ""; - if (!header.StartsWith("Location: mysql://", StringComparison.Ordinal) || header.Length < 22) + // "mariadb/mysql://[{user}[:{password}]@]{host}[:{port}]/[{db}[?{opt1}={value1}[&{opt2}={value2}]]]']" + if (!redirectUrl.StartsWith("mysql://", StringComparison.Ordinal) && !redirectUrl.StartsWith("mariadb://", StringComparison.Ordinal)) return false; - bool isCommunityFormat; - int portIndex; - if (header[18] == '[') - { - // Community protocol: - // Location: mysql://[redirectedHostName]:redirectedPort/?user=redirectedUser&ttl=%d\n - isCommunityFormat = true; - - var hostIndex = 19; - var closeSquareBracketIndex = header.IndexOf(']', hostIndex); - if (closeSquareBracketIndex == -1) - return false; - - host = header[hostIndex..closeSquareBracketIndex]; - if (header.Length <= closeSquareBracketIndex + 2) - return false; - if (header[closeSquareBracketIndex + 1] != ':') - return false; - portIndex = closeSquareBracketIndex + 2; - } - else + try { - // Azure protocol: - // Location: mysql://redirectedHostName:redirectedPort/user=redirectedUser&ttl=%d (where ttl is optional) - isCommunityFormat = false; - - var hostIndex = 18; - var colonIndex = header.IndexOf(':', hostIndex); - if (colonIndex == -1) - return false; + var uri = new Uri(redirectUrl); + host = uri.Host; + if (string.IsNullOrEmpty(host)) return false; + if (host.StartsWith('[') && host.EndsWith("]", StringComparison.InvariantCulture)) host = host.Substring(1, host.Length - 2); + + port = uri.Port; + user = Uri.UnescapeDataString(uri.UserInfo.Split(':')[0]); + if (string.IsNullOrEmpty(user) && !string.IsNullOrEmpty(uri.Query)) + { + // query format "?{opt1}={value1}[&{opt2}={value2}]" + var q = uri.Query.Substring(1); + foreach (var token in q.Split('&')) + { + if (token.StartsWith("user=", StringComparison.InvariantCulture)) + { + user = Uri.UnescapeDataString(token.Substring(5)); + } + } + } - host = header[hostIndex..colonIndex]; - portIndex = colonIndex + 1; + if (string.IsNullOrEmpty(user)) user = initialUser; + return true; } - - var userIndex = header.IndexOf(isCommunityFormat ? "/?user=" : "/user=", StringComparison.Ordinal); - if (userIndex == -1) - return false; - -#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP2_1_OR_GREATER - if (!int.TryParse(header.AsSpan(portIndex, userIndex - portIndex), out port) || port <= 0) -#else - if (!int.TryParse(header[portIndex..userIndex], out port) || port <= 0) -#endif + catch (UriFormatException) + { return false; - - userIndex += isCommunityFormat ? 7 : 6; - var ampersandIndex = header.IndexOf('&', userIndex); - var newlineIndex = header.IndexOf('\n', userIndex); - var terminatorIndex = ampersandIndex == -1 ? (newlineIndex == -1 ? header.Length : newlineIndex) : - (newlineIndex == -1 ? ampersandIndex : Math.Min(ampersandIndex, newlineIndex)); - user = header[userIndex..terminatorIndex]; - return user.Length != 0; + } } public static TimeSpan ParseTimeSpan(ReadOnlySpan value) diff --git a/tests/IntegrationTests/RedirectionTests.cs b/tests/IntegrationTests/RedirectionTests.cs new file mode 100644 index 000000000..892f1eb7f --- /dev/null +++ b/tests/IntegrationTests/RedirectionTests.cs @@ -0,0 +1,194 @@ +#if !MYSQL_DATA +using System.Globalization; +using System.Net; +using System.Net.Sockets; + +namespace IntegrationTests; + +public class RedirectionTests : IClassFixture, IDisposable +{ + public RedirectionTests(DatabaseFixture database) + { + m_database = database; + m_database.Connection.Open(); + } + + public void Dispose() + { + m_database.Connection.Close(); + } + + [SkippableFact(ServerFeatures.Redirection)] + public void RedirectionTest() + { + StartProxy(); + + // wait for proxy to launch + Thread.Sleep(50); + var csb = AppConfig.CreateConnectionStringBuilder(); + var initialServer = csb.Server; + var initialPort = csb.Port; + m_database.Connection.Execute($"set @@global.redirect_url=\"mariadb://{initialServer}:{initialPort}\""); + + try + { + // changing to proxy port + csb.Server = "localhost"; + csb.Port = (uint)proxy.ListenPort; + csb.ServerRedirectionMode = MySqlServerRedirectionMode.Preferred; + + // ensure that connection has been redirected + using (var db = new MySqlConnection(csb.ConnectionString)) + { + db.Open(); + using (var cmd = db.CreateCommand()) + { + cmd.CommandText = "SELECT 1"; + cmd.ExecuteNonQuery(); + } + + Assert.Equal((int) initialPort, db.SessionEndPoint!.Port); + db.Close(); + } + + // ensure that connection has been redirected with Required + csb.ServerRedirectionMode = MySqlServerRedirectionMode.Required; + using (var db = new MySqlConnection(csb.ConnectionString)) + { + db.Open(); + using (var cmd = db.CreateCommand()) + { + cmd.CommandText = "SELECT 1"; + cmd.ExecuteNonQuery(); + } + + Assert.Equal((int) initialPort, db.SessionEndPoint!.Port); + db.Close(); + } + + // ensure that redirection is not done + csb.ServerRedirectionMode = MySqlServerRedirectionMode.Disabled; + using (var db = new MySqlConnection(csb.ConnectionString)) + { + db.Open(); + using (var cmd = db.CreateCommand()) + { + cmd.CommandText = "SELECT 1"; + cmd.ExecuteNonQuery(); + } + + Assert.Equal(proxy.ListenPort, db.SessionEndPoint!.Port); + db.Close(); + } + + } finally{ + m_database.Connection.Execute( + $"set @@global.redirect_url=\"\""); + } + MySqlConnection.ClearAllPools(); + // ensure that when required, throwing error if no redirection + csb.ServerRedirectionMode = MySqlServerRedirectionMode.Required; + using (var db = new MySqlConnection(csb.ConnectionString)) + { + try + { + db.Open(); + Assert.Fail("must have thrown error"); + } + catch (MySqlException ex) + { + Assert.Equal((int) MySqlErrorCode.UnableToConnectToHost, ex.Number); + } + } + + StopProxy(); + } + + protected void StartProxy() + { + var csb = AppConfig.CreateConnectionStringBuilder(); + proxy = new ServerConfiguration( csb.Server, (int)csb.Port ); + Thread serverThread = new Thread( ServerThread ); + serverThread.Start( proxy ); + } + + protected void StopProxy() + { + proxy.RunServer = false; + proxy.ServerSocket.Close(); + } + + private class ServerConfiguration { + + public IPAddress RemoteAddress; + public int RemotePort; + public int ListenPort; + public Socket ServerSocket; + public ServerConfiguration(String remoteAddress, int remotePort) { + var ipHostEntry = Dns.GetHostEntry(remoteAddress); + RemoteAddress = ipHostEntry.AddressList[0]; + RemotePort = remotePort; + ListenPort = 0; + } + public bool RunServer = true; + } + + private static void ServerThread(Object configObj) { + ServerConfiguration config = (ServerConfiguration)configObj; + Socket serverSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + serverSocket.Bind( new IPEndPoint( IPAddress.Any, 0 ) ); + serverSocket.Listen(1); + config.ListenPort = ((IPEndPoint) serverSocket.LocalEndPoint).Port; + config.ServerSocket = serverSocket; + while( config.RunServer ) { + try + { + Socket client = serverSocket.Accept(); + Thread clientThread = new Thread(ClientThread); + clientThread.Start(new ClientContext() { Config = config, Client = client }); + } + catch (SocketException) when (!config.RunServer) + { + return; + } + } + } + + private class ClientContext { + public ServerConfiguration Config; + public Socket Client; + } + + private static void ClientThread(Object contextObj) { + ClientContext context = (ClientContext)contextObj; + Socket client = context.Client; + ServerConfiguration config = context.Config; + IPEndPoint remoteEndPoint = new IPEndPoint( config.RemoteAddress, config.RemotePort ); + Socket remote = new Socket( remoteEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + remote.Connect( remoteEndPoint ); + Byte[] buffer = new Byte[4096]; + for(;;) { + if (!config.RunServer) + { + remote.Close(); + client.Close(); + return; + } + if( client.Available > 0 ) { + var count = client.Receive( buffer ); + if( count == 0 ) return; + remote.Send( buffer, count, SocketFlags.None ); + } + if( remote.Available > 0 ) { + var count = remote.Receive( buffer ); + if( count == 0 ) return; + client.Send( buffer, count, SocketFlags.None ); + } + } + } + + readonly DatabaseFixture m_database; + private ServerConfiguration proxy; +} +#endif diff --git a/tests/IntegrationTests/ServerFeatures.cs b/tests/IntegrationTests/ServerFeatures.cs index cefe5d377..ac4fa4863 100644 --- a/tests/IntegrationTests/ServerFeatures.cs +++ b/tests/IntegrationTests/ServerFeatures.cs @@ -35,4 +35,9 @@ public enum ServerFeatures /// A "SLEEP" command produces a result set when it is cancelled, not an error payload. /// CancelSleepSuccessfully = 0x40_0000, + + /// + /// Server permit redirection, available on first OK_Packet + /// + Redirection = 0x80_0000, } diff --git a/tests/MySqlConnector.Tests/UtilityTests.cs b/tests/MySqlConnector.Tests/UtilityTests.cs index 57cdcc1b1..8040c2284 100644 --- a/tests/MySqlConnector.Tests/UtilityTests.cs +++ b/tests/MySqlConnector.Tests/UtilityTests.cs @@ -8,21 +8,17 @@ namespace MySqlConnector.Tests; public class UtilityTests { [Theory] - [InlineData("Location: mysql://host.example.com:1234/user=user@host", "host.example.com", 1234, "user@host")] - [InlineData("Location: mysql://host.example.com:1234/user=user@host\n", "host.example.com", 1234, "user@host")] - [InlineData("Location: mysql://host.example.com:1234/user=user@host&ttl=60", "host.example.com", 1234, "user@host")] - [InlineData("Location: mysql://host.example.com:1234/user=user@host&ttl=60\n", "host.example.com", 1234, "user@host")] - [InlineData("Location: mysql://[host.example.com]:1234/?user=abcd", "host.example.com", 1234, "abcd")] - [InlineData("Location: mysql://[host.example.com]:1234/?user=abcd\n", "host.example.com", 1234, "abcd")] - [InlineData("Location: mysql://[host.example.com]:1234/?user=abcd&ttl=60", "host.example.com", 1234, "abcd")] - [InlineData("Location: mysql://[host.example.com]:1234/?user=abcd&ttl=60\n", "host.example.com", 1234, "abcd")] - [InlineData("Location: mysql://[2001:4860:4860::8888]:1234/?user=abcd", "2001:4860:4860::8888", 1234, "abcd")] - [InlineData("Location: mysql://[2001:4860:4860::8888]:1234/?user=abcd\n", "2001:4860:4860::8888", 1234, "abcd")] - [InlineData("Location: mysql://[2001:4860:4860::8888]:1234/?user=abcd&ttl=60", "2001:4860:4860::8888", 1234, "abcd")] - [InlineData("Location: mysql://[2001:4860:4860::8888]:1234/?user=abcd&ttl=60\n", "2001:4860:4860::8888", 1234, "abcd")] + [InlineData("mariadb://host.example.com:1234/?user=user@host", "host.example.com", 1234, "user@host")] + [InlineData("mariadb://user%40host:password@host.example.com:1234/", "host.example.com", 1234, "user@host")] + [InlineData("mariadb://host.example.com:1234/?user=user@host&ttl=60", "host.example.com", 1234, "user@host")] + [InlineData("mariadb://someuser:password@host.example.com:1234/?user=user@host&ttl=60\n", "host.example.com", 1234, "someuser")] + [InlineData("mysql://[2001:4860:4860::8888]:1234/?user=abcd", "2001:4860:4860::8888", 1234, "abcd")] + [InlineData("mysql://[2001:4860:4860::8888]:1234/?user=abcd\n", "2001:4860:4860::8888", 1234, "abcd")] + [InlineData("mysql://[2001:4860:4860::8888]:1234/?user=abcd&ttl=60", "2001:4860:4860::8888", 1234, "abcd")] + [InlineData("mysql://[2001:4860:4860::8888]:1234/?user=abcd&ttl=60\n", "2001:4860:4860::8888", 1234, "abcd")] public void ParseRedirectionHeader(string input, string expectedHost, int expectedPort, string expectedUser) { - Assert.True(Utility.TryParseRedirectionHeader(input, out var host, out var port, out var user)); + Assert.True(Utility.TryParseRedirectionHeader(input, null, out var host, out var port, out var user)); Assert.Equal(expectedHost, host); Assert.Equal(expectedPort, port); Assert.Equal(expectedUser, user); @@ -30,26 +26,14 @@ public void ParseRedirectionHeader(string input, string expectedHost, int expect [Theory] [InlineData("")] - [InlineData("Location: mysql")] - [InlineData("Location: mysql://host.example.com")] - [InlineData("Location: mysql://host.example.com:")] - [InlineData("Location: mysql://[host.example.com")] - [InlineData("Location: mysql://[host.example.com]")] - [InlineData("Location: mysql://[host.example.com]:")] - [InlineData("Location: mysql://host.example.com:123")] - [InlineData("Location: mysql://host.example.com:123/")] - [InlineData("Location: mysql://[host.example.com]:123")] - [InlineData("Location: mysql://[host.example.com]:123/")] - [InlineData("Location: mysql://host.example.com:/user=")] - [InlineData("Location: mysql://host.example.com:123/user=")] - [InlineData("Location: mysql://[host.example.com]:123/?user=")] - [InlineData("Location: mysql://host.example.com:/user=user@host")] - [InlineData("Location: mysql://host.example.com:-1/user=user@host")] - [InlineData("Location: mysql://host.example.com:0/user=user@host")] - [InlineData("Location: mysql://[host.example.com]:123/user=abcd")] + [InlineData("not formated")] + [InlineData("mysql")] + [InlineData("mysql://[host.example.com")] + [InlineData("mysql://host.example.com:-1/user=user@host")] + [InlineData("mysql://[host.example.com]:123/user=abcd")] public void ParseRedirectionHeaderFails(string input) { - Assert.False(Utility.TryParseRedirectionHeader(input, out _, out _, out _)); + Assert.False(Utility.TryParseRedirectionHeader(input, null, out _, out _, out _)); } [Theory]