Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool)
public bool ProcAccessDenied { get; set; }
public ICollection<KeyValuePair<string, object?>> ActivityTags => m_activityTags;
public MySqlDataReader DataReader { get; set; }
public MySqlConnectionOpenedConditions Conditions { get; private set; }

public ValueTask ReturnToPoolAsync(IOBehavior ioBehavior, MySqlConnection? owningConnection)
{
Log.ReturningToPool(m_logger, Id, Pool?.Id ?? 0);
Conditions = MySqlConnectionOpenedConditions.None;
LastReturnedTimestamp = Stopwatch.GetTimestamp();
if (Pool is null)
return default;
Expand Down Expand Up @@ -414,6 +416,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
}
}

Conditions = MySqlConnectionOpenedConditions.New;
var connected = cs.ConnectionProtocol switch
{
MySqlConnectionProtocol.Sockets => await OpenTcpSocketAsync(cs, loadBalancer ?? throw new ArgumentNullException(nameof(loadBalancer)), activity, ioBehavior, cancellationToken).ConfigureAwait(false),
Expand Down Expand Up @@ -747,6 +750,7 @@ public static async ValueTask<ServerSession> ConnectAndRedirectAsync(ILogger con
public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConnection connection, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
VerifyState(State.Connected);
Conditions |= MySqlConnectionOpenedConditions.Reset;

try
{
Expand Down Expand Up @@ -829,6 +833,7 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
Log.IgnoringFailureInTryResetConnectionAsync(m_logger, ex, Id, "SocketException");
}

Conditions &= ~MySqlConnectionOpenedConditions.Reset;
return false;
}

Expand Down
20 changes: 20 additions & 0 deletions src/MySqlConnector/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,13 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella
ActivitySourceHelper.CopyTags(m_session!.ActivityTags, activity);
m_hasBeenOpened = true;
SetState(ConnectionState.Open);

if (ConnectionOpenedCallback is { } autoEnlistConnectionOpenedCallback)
{
cancellationToken.ThrowIfCancellationRequested();
await autoEnlistConnectionOpenedCallback(new(this, MySqlConnectionOpenedConditions.None), cancellationToken).ConfigureAwait(false);
}

return;
}
}
Expand Down Expand Up @@ -582,6 +589,12 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella

if (m_connectionSettings.AutoEnlist && System.Transactions.Transaction.Current is not null)
EnlistTransaction(System.Transactions.Transaction.Current);

if (ConnectionOpenedCallback is { } connectionOpenedCallback)
{
cancellationToken.ThrowIfCancellationRequested();
await connectionOpenedCallback(new(this, m_session.Conditions), cancellationToken).ConfigureAwait(false);
}
}
catch (Exception ex) when (activity is { IsAllDataRequested: true })
{
Expand Down Expand Up @@ -917,6 +930,11 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel)

using var connection = CloneWith(csb.ConnectionString);
connection.m_connectionSettings = connectionSettings;

// clear the callback because this is not intended to be a user-visible MySqlConnection that will execute setup logic; it's a
// non-pooled connection that will execute "KILL QUERY" then immediately be closed
connection.ConnectionOpenedCallback = null;

connection.Open();
#if NET6_0_OR_GREATER
var killQuerySql = string.Create(CultureInfo.InvariantCulture, $"KILL QUERY {command.Connection!.ServerThread}");
Expand Down Expand Up @@ -992,6 +1010,7 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel)
internal MySqlTransaction? CurrentTransaction { get; set; }
internal MySqlConnectorLoggingConfiguration LoggingConfiguration { get; }
internal ZstandardPlugin? ZstandardPlugin { get; set; }
internal MySqlConnectionOpenedCallback? ConnectionOpenedCallback { get; set; }
internal bool AllowLoadLocalInfile => GetInitializedConnectionSettings().AllowLoadLocalInfile;
internal bool AllowUserVariables => GetInitializedConnectionSettings().AllowUserVariables;
internal bool AllowZeroDateTime => GetInitializedConnectionSettings().AllowZeroDateTime;
Expand Down Expand Up @@ -1142,6 +1161,7 @@ private MySqlConnection(MySqlConnection other, MySqlDataSource? dataSource, stri
ProvideClientCertificatesCallback = other.ProvideClientCertificatesCallback;
ProvidePasswordCallback = other.ProvidePasswordCallback;
RemoteCertificateValidationCallback = other.RemoteCertificateValidationCallback;
ConnectionOpenedCallback = other.ConnectionOpenedCallback;
}

private void VerifyNotDisposed()
Expand Down
9 changes: 9 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedCallback.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace MySqlConnector;

/// <summary>
/// A callback that is invoked when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
/// <param name="context">A <see cref="MySqlConnectionOpenedContext"/> giving information about the connection being opened.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that can be used to cancel the asynchronous operation.</param>
/// <returns>A <see cref="ValueTask"/> representing the result of the possibly-asynchronous operation.</returns>
public delegate ValueTask MySqlConnectionOpenedCallback(MySqlConnectionOpenedContext context, CancellationToken cancellationToken);
23 changes: 23 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedConditions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace MySqlConnector;

/// <summary>
/// Bitflags giving the conditions under which a connection was opened.
/// </summary>
[Flags]
public enum MySqlConnectionOpenedConditions
{
/// <summary>
/// No specific conditions apply. This value may be used when an existing pooled connection is reused without being reset.
/// </summary>
None = 0,

/// <summary>
/// A new physical connection to a MySQL Server was opened. This value is mutually exclusive with <see cref="Reset"/>.
/// </summary>
New = 1,

/// <summary>
/// An existing pooled connection to a MySQL Server was reset. This value is mutually exclusive with <see cref="New"/>.
/// </summary>
Reset = 2,
}
23 changes: 23 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace MySqlConnector;

/// <summary>
/// Contains information passed to <see cref="MySqlConnectionOpenedCallback"/> when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
public sealed class MySqlConnectionOpenedContext
{
/// <summary>
/// The <see cref="MySqlConnection"/> that was opened.
/// </summary>
public MySqlConnection Connection { get; }

/// <summary>
/// Bitflags giving the conditions under which a connection was opened.
/// </summary>
public MySqlConnectionOpenedConditions Conditions { get; }

internal MySqlConnectionOpenedContext(MySqlConnection connection, MySqlConnectionOpenedConditions conditions)
{
Connection = connection;
Conditions = conditions;
}
}
8 changes: 6 additions & 2 deletions src/MySqlConnector/MySqlDataSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public sealed class MySqlDataSource : DbDataSource
/// <param name="connectionString">The connection string for the MySQL Server. This parameter is required.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="connectionString"/> is <c>null</c>.</exception>
public MySqlDataSource(string connectionString)
: this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default)
: this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default, default)
{
}

Expand All @@ -31,7 +31,8 @@ internal MySqlDataSource(string connectionString,
Func<MySqlProvidePasswordContext, CancellationToken, ValueTask<string>>? periodicPasswordProvider,
TimeSpan periodicPasswordProviderSuccessRefreshInterval,
TimeSpan periodicPasswordProviderFailureRefreshInterval,
ZstandardPlugin? zstandardPlugin)
ZstandardPlugin? zstandardPlugin,
MySqlConnectionOpenedCallback? connectionOpenedCallback)
{
m_connectionString = connectionString;
LoggingConfiguration = loggingConfiguration;
Expand All @@ -40,6 +41,7 @@ internal MySqlDataSource(string connectionString,
m_remoteCertificateValidationCallback = remoteCertificateValidationCallback;
m_logger = loggingConfiguration.DataSourceLogger;
m_zstandardPlugin = zstandardPlugin;
m_connectionOpenedCallback = connectionOpenedCallback;

Pool = ConnectionPool.CreatePool(m_connectionString, LoggingConfiguration, name);
m_id = Interlocked.Increment(ref s_lastId);
Expand Down Expand Up @@ -142,6 +144,7 @@ protected override DbConnection CreateDbConnection()
ProvideClientCertificatesCallback = m_clientCertificatesCallback,
ProvidePasswordCallback = m_providePasswordCallback,
RemoteCertificateValidationCallback = m_remoteCertificateValidationCallback,
ConnectionOpenedCallback = m_connectionOpenedCallback,
};
}

Expand Down Expand Up @@ -225,6 +228,7 @@ private string ProvidePasswordFromInitialRefreshTask(MySqlProvidePasswordContext
private readonly TimeSpan m_periodicPasswordProviderSuccessRefreshInterval;
private readonly TimeSpan m_periodicPasswordProviderFailureRefreshInterval;
private readonly ZstandardPlugin? m_zstandardPlugin;
private readonly MySqlConnectionOpenedCallback? m_connectionOpenedCallback;
private readonly MySqlProvidePasswordContext? m_providePasswordContext;
private readonly CancellationTokenSource? m_passwordProviderTimerCancellationTokenSource;
private readonly Timer? m_passwordProviderTimer;
Expand Down
15 changes: 14 additions & 1 deletion src/MySqlConnector/MySqlDataSourceBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ public MySqlDataSourceBuilder UseRemoteCertificateValidationCallback(RemoteCerti
return this;
}

/// <summary>
/// Adds a callback that is invoked when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
/// <param name="callback">The callback to invoke.</param>
/// <returns>This builder, so that method calls can be chained.</returns>
public MySqlDataSourceBuilder UseConnectionOpenedCallback(MySqlConnectionOpenedCallback callback)
{
m_connectionOpenedCallback += callback;
return this;
}

/// <summary>
/// Builds a <see cref="MySqlDataSource"/> which is ready for use.
/// </summary>
Expand All @@ -104,7 +115,8 @@ public MySqlDataSource Build()
m_periodicPasswordProvider,
m_periodicPasswordProviderSuccessRefreshInterval,
m_periodicPasswordProviderFailureRefreshInterval,
ZstandardPlugin
ZstandardPlugin,
m_connectionOpenedCallback
);
}

Expand All @@ -122,4 +134,5 @@ public MySqlDataSource Build()
private Func<MySqlProvidePasswordContext, CancellationToken, ValueTask<string>>? m_periodicPasswordProvider;
private TimeSpan m_periodicPasswordProviderSuccessRefreshInterval;
private TimeSpan m_periodicPasswordProviderFailureRefreshInterval;
private MySqlConnectionOpenedCallback? m_connectionOpenedCallback;
}
32 changes: 32 additions & 0 deletions tests/IntegrationTests/TransactionScopeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,38 @@ public void Bug1348()

Assert.True(rollbacked, $"First branch transaction '{xid}1' not rolled back");
}

[Fact]
public void ConnectionOpenedCallbackAutoEnlistInTransaction()
{
var connectionOpenedCallbackCount = 0;
var connectionOpenedConditions = MySqlConnectionOpenedConditions.None;
using var dataSource = new MySqlDataSourceBuilder(AppConfig.ConnectionString)
.UseConnectionOpenedCallback((ctx, token) =>
{
connectionOpenedCallbackCount++;
connectionOpenedConditions = ctx.Conditions;
return default;
})
.Build();

using (var transactionScope = new TransactionScope())
{
using (var conn = dataSource.OpenConnection())
{
Assert.Equal(1, connectionOpenedCallbackCount);
Assert.Equal(MySqlConnectionOpenedConditions.New, connectionOpenedConditions);
}

using (var conn = dataSource.OpenConnection())
{
Assert.Equal(2, connectionOpenedCallbackCount);
Assert.Equal(MySqlConnectionOpenedConditions.None, connectionOpenedConditions);
}

transactionScope.Complete();
}
}
#endif

readonly DatabaseFixture m_database;
Expand Down
Loading