Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Current package versions:

## Unreleased

- (none)
- Add overrideable `AfterDisconnectAsync()` callback on `DefaultOptionsProvider` ([#2952 by philon-msft](https://github.com/StackExchange/StackExchange.Redis/pull/2952))

## 2.9.17

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ protected virtual string GetDefaultClientName() =>
/// <param name="log">The logger for the connection, to emit to the connection output log.</param>
public virtual Task AfterConnectAsync(ConnectionMultiplexer multiplexer, Action<string> log) => Task.CompletedTask;

/// <summary>
/// The action to perform, if any, immediately after a connection is closed.
/// </summary>
/// <param name="multiplexer">The multiplexer that just disconnected.</param>
public virtual Task AfterDisconnectAsync(ConnectionMultiplexer multiplexer) => Task.CompletedTask;

/// <summary>
/// Gets the default SSL "enabled or not" based on a set of endpoints.
/// Note: this setting then applies for *all* endpoints.
Expand Down
10 changes: 6 additions & 4 deletions src/StackExchange.Redis/ConfigurationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ public DefaultOptionsProvider Defaults

internal Func<ConnectionMultiplexer, Action<string>, Task> AfterConnectAsync => Defaults.AfterConnectAsync;

internal Func<ConnectionMultiplexer, Task> AfterDisconnectAsync => Defaults.AfterDisconnectAsync;

/// <summary>
/// Gets or sets whether connect/configuration timeouts should be explicitly notified via a TimeoutException.
/// </summary>
Expand Down Expand Up @@ -305,8 +307,8 @@ public bool HighIntegrity
/// <summary>
/// Supply a user certificate from a PEM file pair and enable TLS.
/// </summary>
/// <param name="userCertificatePath">The path for the the user certificate (commonly a .crt file).</param>
/// <param name="userKeyPath">The path for the the user key (commonly a .key file).</param>
/// <param name="userCertificatePath">The path for the user certificate (commonly a .crt file).</param>
/// <param name="userKeyPath">The path for the user key (commonly a .key file).</param>
public void SetUserPemCertificate(string userCertificatePath, string? userKeyPath = null)
{
CertificateSelectionCallback = CreatePemUserCertificateCallback(userCertificatePath, userKeyPath);
Expand All @@ -317,7 +319,7 @@ public void SetUserPemCertificate(string userCertificatePath, string? userKeyPat
/// <summary>
/// Supply a user certificate from a PFX file and optional password and enable TLS.
/// </summary>
/// <param name="userCertificatePath">The path for the the user certificate (commonly a .pfx file).</param>
/// <param name="userCertificatePath">The path for the user certificate (commonly a .pfx file).</param>
/// <param name="password">The password for the certificate file.</param>
public void SetUserPfxCertificate(string userCertificatePath, string? password = null)
{
Expand Down Expand Up @@ -383,7 +385,7 @@ private static bool CheckTrustedIssuer(X509Certificate2 certificateToValidate, X
chain.ChainPolicy.VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority;
chain.ChainPolicy.VerificationTime = chainToValidate?.ChainPolicy?.VerificationTime ?? DateTime.Now;
chain.ChainPolicy.UrlRetrievalTimeout = new TimeSpan(0, 0, 0);
// Ensure entended key usage checks are run and that we're observing a server TLS certificate
// Ensure intended key usage checks are run and that we're observing a server TLS certificate
chain.ChainPolicy.ApplicationPolicy.Add(_serverAuthOid);

chain.ChainPolicy.ExtraStore.Add(authority);
Expand Down
10 changes: 10 additions & 0 deletions src/StackExchange.Redis/ConnectionMultiplexer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2294,9 +2294,11 @@ public void Close(bool allowCommandsToComplete = true)
var quits = QuitAllServers();
WaitAllIgnoreErrors(quits);
}

DisposeAndClearServers();
OnCloseReaderWriter();
OnClosing(true);
RawConfig.AfterDisconnectAsync?.Invoke(this).Wait(SyncConnectTimeout(true));
Interlocked.Increment(ref _connectionCloseCount);
}

Expand All @@ -2306,7 +2308,11 @@ public void Close(bool allowCommandsToComplete = true)
/// <param name="allowCommandsToComplete">Whether to allow all in-queue commands to complete first.</param>
public async Task CloseAsync(bool allowCommandsToComplete = true)
{
if (_isDisposed) return;

OnClosing(false);
_isDisposed = true;
_profilingSessionProvider = null;
using (var tmp = pulse)
{
pulse = null;
Expand All @@ -2319,6 +2325,10 @@ public async Task CloseAsync(bool allowCommandsToComplete = true)
}

DisposeAndClearServers();
OnCloseReaderWriter();
OnClosing(true);
await RawConfig.AfterDisconnectAsync(this).ForAwait();
Interlocked.Increment(ref _connectionCloseCount);
}

private void DisposeAndClearServers()
Expand Down
1 change: 1 addition & 0 deletions src/StackExchange.Redis/PublicAPI/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,7 @@ static StackExchange.Redis.StreamPosition.Beginning.get -> StackExchange.Redis.R
static StackExchange.Redis.StreamPosition.NewMessages.get -> StackExchange.Redis.RedisValue
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.AbortOnConnectFail.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.AfterConnectAsync(StackExchange.Redis.ConnectionMultiplexer! multiplexer, System.Action<string!>! log) -> System.Threading.Tasks.Task!
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.AfterDisconnectAsync(StackExchange.Redis.ConnectionMultiplexer! multiplexer) -> System.Threading.Tasks.Task!
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.AllowAdmin.get -> bool
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.BacklogPolicy.get -> StackExchange.Redis.BacklogPolicy!
virtual StackExchange.Redis.Configuration.DefaultOptionsProvider.CheckCertificateRevocation.get -> bool
Expand Down
25 changes: 25 additions & 0 deletions tests/StackExchange.Redis.Tests/DefaultOptionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,31 @@ public async Task AfterConnectAsyncHandler()
Assert.Equal(1, provider.Calls);
}

public class TestAfterDisconnectOptionsProvider : DefaultOptionsProvider
{
public int Calls;

public override Task AfterDisconnectAsync(ConnectionMultiplexer muxer)
{
Interlocked.Increment(ref Calls);
return Task.CompletedTask;
}
}

[Fact]
public async Task AfterDisconnectAsyncHandler()
{
var options = ConfigurationOptions.Parse(GetConfiguration());
var provider = new TestAfterDisconnectOptionsProvider();
options.Defaults = provider;

await using var conn = await ConnectionMultiplexer.ConnectAsync(options, Writer);
await conn.CloseAsync();

Assert.False(conn.IsConnected);
Assert.Equal(1, provider.Calls);
}

public class TestClientNameOptionsProvider : DefaultOptionsProvider
{
protected override string GetDefaultClientName() => "Hey there";
Expand Down
Loading