diff --git a/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs b/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs index 78f7fcc09f1..8229ba660b5 100644 --- a/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs +++ b/src/MongoDB.Driver/Core/Connections/TcpStreamFactory.cs @@ -165,10 +165,17 @@ private void ConfigureConnectedSocket(Socket socket) private void Connect(Socket socket, EndPoint endPoint, CancellationToken cancellationToken) { - var isSocketDisposed = false; + var callbackState = new OperationCallbackState(socket); using var timeoutCancellationTokenSource = new CancellationTokenSource(_settings.ConnectTimeout); using var combinedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token); - using var cancellationSubscription = combinedCancellationTokenSource.Token.Register(DisposeSocket); + using var cancellationSubscription = combinedCancellationTokenSource.Token.Register(state => + { + var operationState = (OperationCallbackState)state; + if (operationState.TryChangeStatusFromInProgress(OperationCallbackState.OperationStatus.Interrupted)) + { + DisposeSocket(operationState.Subject); + } + }, callbackState); try { @@ -185,13 +192,14 @@ private void Connect(Socket socket, EndPoint endPoint, CancellationToken cancell #else socket.Connect(endPoint); #endif + if (!callbackState.TryChangeStatusFromInProgress(OperationCallbackState.OperationStatus.Done)) + { + throw new ObjectDisposedException(nameof(Socket)); + } } catch { - if (!isSocketDisposed) - { - DisposeSocket(); - } + DisposeSocket(socket); cancellationToken.ThrowIfCancellationRequested(); if (timeoutCancellationTokenSource.IsCancellationRequested) @@ -202,9 +210,8 @@ private void Connect(Socket socket, EndPoint endPoint, CancellationToken cancell throw; } - void DisposeSocket() + static void DisposeSocket(Socket socket) { - isSocketDisposed = true; try { socket.Dispose(); diff --git a/src/MongoDB.Driver/Core/Misc/OperationCallbackState.cs b/src/MongoDB.Driver/Core/Misc/OperationCallbackState.cs new file mode 100644 index 00000000000..f04794eccad --- /dev/null +++ b/src/MongoDB.Driver/Core/Misc/OperationCallbackState.cs @@ -0,0 +1,35 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Threading; + +namespace MongoDB.Driver.Core.Misc; + +internal sealed class OperationCallbackState(T subject) +{ + private int _status = (int)OperationStatus.InProgress; + + public OperationStatus Status => (OperationStatus)_status; + public T Subject => subject; + public bool TryChangeStatusFromInProgress(OperationStatus newState) => + Interlocked.CompareExchange(ref _status, (int)newState, (int)OperationStatus.InProgress) == (int)OperationStatus.InProgress; + + public enum OperationStatus + { + InProgress = 0, + Done, + Interrupted, + } +} diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 06e0d4a5ac2..11bc4edaed7 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -287,25 +287,25 @@ private static void ExecuteOperationWithTimeout(Stream stream, TState st throw new TimeoutException(); } - StreamDisposeCallbackState callbackState = null; + OperationCallbackState callbackState = null; Timer timer = null; CancellationTokenRegistration cancellationSubscription = default; if (timeoutMs > 0) { - callbackState = new StreamDisposeCallbackState(stream); + callbackState = new OperationCallbackState(stream); timer = new Timer(DisposeStreamCallback, callbackState, timeoutMs, Timeout.Infinite); } if (cancellationToken.CanBeCanceled) { - callbackState ??= new StreamDisposeCallbackState(stream); + callbackState ??= new OperationCallbackState(stream); cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState); } try { operation(stream, state); - if (callbackState?.TryChangeStateFromInProgress(OperationState.Done) == false) + if (callbackState?.TryChangeStatusFromInProgress(OperationCallbackState.OperationStatus.Done) == false) { // If the state can't be changed - then the stream was/will be disposed, throw here throw new IOException(); @@ -313,7 +313,7 @@ private static void ExecuteOperationWithTimeout(Stream stream, TState st } catch (Exception ex) { - if (callbackState?.OperationState == OperationState.Interrupted) + if (callbackState?.Status == OperationCallbackState.OperationStatus.Interrupted) { cancellationToken.ThrowIfCancellationRequested(); throw new TimeoutException(); @@ -334,8 +334,8 @@ private static void ExecuteOperationWithTimeout(Stream stream, TState st static void DisposeStreamCallback(object state) { - var disposeCallbackState = (StreamDisposeCallbackState)state; - if (!disposeCallbackState.TryChangeStateFromInProgress(OperationState.Interrupted)) + var disposeCallbackState = (OperationCallbackState)state; + if (!disposeCallbackState.TryChangeStatusFromInProgress(OperationCallbackState.OperationStatus.Interrupted)) { // If the state can't be changed - then I/O had already succeeded return; @@ -343,7 +343,7 @@ static void DisposeStreamCallback(object state) try { - disposeCallbackState.Stream.Dispose(); + disposeCallbackState.Subject.Dispose(); } catch (Exception) { @@ -351,22 +351,5 @@ static void DisposeStreamCallback(object state) } } } - - private record StreamDisposeCallbackState(Stream Stream) - { - private int _operationState = (int)OperationState.InProgress; - - public OperationState OperationState => (OperationState)_operationState; - - public bool TryChangeStateFromInProgress(OperationState newState) => - Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress; - } - - private enum OperationState - { - InProgress = 0, - Done, - Interrupted, - } } } diff --git a/tests/MongoDB.TestHelpers/XunitExtensions/TimeoutEnforcing/UnobservedExceptionTestDiscoverer.cs b/tests/MongoDB.TestHelpers/XunitExtensions/TimeoutEnforcing/UnobservedExceptionTestDiscoverer.cs index b4fdc098c48..d38c4d5ca81 100644 --- a/tests/MongoDB.TestHelpers/XunitExtensions/TimeoutEnforcing/UnobservedExceptionTestDiscoverer.cs +++ b/tests/MongoDB.TestHelpers/XunitExtensions/TimeoutEnforcing/UnobservedExceptionTestDiscoverer.cs @@ -42,13 +42,16 @@ public UnobservedExceptionTestDiscoverer(IMessageSink diagnosticsMessageSink) public IEnumerable Discover(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, IAttributeInfo factAttribute) { - return [new XunitTestCase(_diagnosticsMessageSink, TestMethodDisplay.Method, TestMethodDisplayOptions.All, testMethod) + var testCase = new XunitTestCase(_diagnosticsMessageSink, TestMethodDisplay.Method, TestMethodDisplayOptions.All, testMethod); + if (!testCase.Traits.TryGetValue("Category", out var categories)) { - Traits = - { - { "Category", ["UnobservedExceptionTracking"] } - } - }]; + categories = new List(); + testCase.Traits.Add("Category", categories); + } + + categories.Add("UnobservedExceptionTracking"); + + return [testCase]; } void UnobservedTaskExceptionEventHandler(object sender, UnobservedTaskExceptionEventArgs unobservedException) =>