Skip to content

Commit a947ded

Browse files
authored
Fix | AsyncHelper.WaitForCompletion leaks unobserved exceptions (#692)
1 parent 4107f24 commit a947ded

File tree

5 files changed

+89
-22
lines changed

5 files changed

+89
-22
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
using System.Transactions;
1818
using Microsoft.Data.Common;
1919

20+
[assembly: InternalsVisibleTo("FunctionalTests")]
21+
2022
namespace Microsoft.Data.SqlClient
2123
{
2224
internal static class AsyncHelper
@@ -204,6 +206,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout
204206
}
205207
if (!task.IsCompleted)
206208
{
209+
task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception
207210
if (onTimeout != null)
208211
{
209212
onTimeout();

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
using System.Threading.Tasks;
1818
using SysTx = System.Transactions;
1919

20+
[assembly: InternalsVisibleTo("FunctionalTests")]
21+
2022
namespace Microsoft.Data.SqlClient
2123
{
2224
using Microsoft.Data.Common;
@@ -189,6 +191,7 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout
189191
}
190192
if (!task.IsCompleted)
191193
{
194+
task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception
192195
if (onTimeout != null)
193196
{
194197
onTimeout();

src/Microsoft.Data.SqlClient/tests/FunctionalTests/BaseProviderAsyncTest/BaseProviderAsyncTest.cs

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
1313
{
1414
public static class BaseProviderAsyncTest
1515
{
16+
private static void AssertTaskFaults(Task t)
17+
{
18+
Assert.ThrowsAny<Exception>(() => t.Wait(TimeSpan.FromMilliseconds(1)));
19+
}
20+
1621
[Fact]
1722
public static void TestDbConnection()
1823
{
@@ -37,8 +42,8 @@ public static void TestDbConnection()
3742
{
3843
Fail = true
3944
};
40-
connectionFail.OpenAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
41-
connectionFail.OpenAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
45+
AssertTaskFaults(connectionFail.OpenAsync());
46+
AssertTaskFaults(connectionFail.OpenAsync(source.Token));
4247

4348
// Verify base implementation does not call Open when passed an already cancelled cancellation token
4449
source.Cancel();
@@ -90,14 +95,14 @@ public static void TestDbCommand()
9095
{
9196
Fail = true
9297
};
93-
commandFail.ExecuteNonQueryAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
94-
commandFail.ExecuteNonQueryAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
95-
commandFail.ExecuteReaderAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
96-
commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
97-
commandFail.ExecuteReaderAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
98-
commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
99-
commandFail.ExecuteScalarAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
100-
commandFail.ExecuteScalarAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
98+
AssertTaskFaults(commandFail.ExecuteNonQueryAsync());
99+
AssertTaskFaults(commandFail.ExecuteNonQueryAsync(source.Token));
100+
AssertTaskFaults(commandFail.ExecuteReaderAsync());
101+
AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess));
102+
AssertTaskFaults(commandFail.ExecuteReaderAsync(source.Token));
103+
AssertTaskFaults(commandFail.ExecuteReaderAsync(CommandBehavior.SequentialAccess, source.Token));
104+
AssertTaskFaults(commandFail.ExecuteScalarAsync());
105+
AssertTaskFaults(commandFail.ExecuteScalarAsync(source.Token));
101106

102107
// Verify base implementation does not call Open when passed an already cancelled cancellation token
103108
source.Cancel();
@@ -116,17 +121,17 @@ public static void TestDbCommand()
116121
source = new CancellationTokenSource();
117122
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
118123
Task result = command.ExecuteNonQueryAsync(source.Token);
119-
Assert.True(result.IsFaulted, "Task result should be faulted");
124+
Assert.True(result.Exception != null, "Task result should be faulted");
120125

121126
source = new CancellationTokenSource();
122127
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
123128
result = command.ExecuteReaderAsync(source.Token);
124-
Assert.True(result.IsFaulted, "Task result should be faulted");
129+
Assert.True(result.Exception != null, "Task result should be faulted");
125130

126131
source = new CancellationTokenSource();
127132
Task.Factory.StartNew(() => { command.WaitForWaitingForCancel(); source.Cancel(); });
128133
result = command.ExecuteScalarAsync(source.Token);
129-
Assert.True(result.IsFaulted, "Task result should be faulted");
134+
Assert.True(result.Exception != null, "Task result should be faulted");
130135
}
131136

132137
[Fact]
@@ -155,9 +160,9 @@ public static void TestDbDataReader()
155160

156161
GetFieldValueAsync<object>(reader, 2, DBNull.Value);
157162
GetFieldValueAsync<DBNull>(reader, 2, DBNull.Value);
158-
reader.GetFieldValueAsync<int?>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
159-
reader.GetFieldValueAsync<string>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
160-
reader.GetFieldValueAsync<bool>(2).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
163+
AssertTaskFaults(reader.GetFieldValueAsync<int?>(2));
164+
AssertTaskFaults(reader.GetFieldValueAsync<string>(2));
165+
AssertTaskFaults(reader.GetFieldValueAsync<bool>(2));
161166
AssertEqualsWithDescription("GetValue", reader.LastCommand, "Last command was not as expected");
162167

163168
result = reader.ReadAsync();
@@ -174,12 +179,12 @@ public static void TestDbDataReader()
174179
Assert.False(result.Result, "Should NOT have received a Result from NextResultAsync");
175180

176181
MockDataReader readerFail = new MockDataReader { Results = query.GetEnumerator(), Fail = true };
177-
readerFail.ReadAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
178-
readerFail.ReadAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
179-
readerFail.NextResultAsync().ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
180-
readerFail.NextResultAsync(source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
181-
readerFail.GetFieldValueAsync<object>(0).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
182-
readerFail.GetFieldValueAsync<object>(0, source.Token).ContinueWith((t) => { }, TaskContinuationOptions.OnlyOnFaulted).Wait();
182+
AssertTaskFaults(readerFail.ReadAsync());
183+
AssertTaskFaults(readerFail.ReadAsync(source.Token));
184+
AssertTaskFaults(readerFail.NextResultAsync());
185+
AssertTaskFaults(readerFail.NextResultAsync(source.Token));
186+
AssertTaskFaults(readerFail.GetFieldValueAsync<object>(0));
187+
AssertTaskFaults(readerFail.GetFieldValueAsync<object>(0, source.Token));
183188

184189
source.Cancel();
185190
reader.LastCommand = "Nothing";

src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
<Compile Include="SqlCredentialTest.cs" />
4343
<Compile Include="SqlDataRecordTest.cs" />
4444
<Compile Include="SqlExceptionTest.cs" />
45+
<Compile Include="SqlHelperTest.cs" />
4546
<Compile Include="SqlParameterTest.cs" />
4647
<Compile Include="SqlClientFactoryTest.cs" />
4748
<Compile Include="SqlErrorCollectionTest.cs" />
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Xunit;
9+
10+
namespace Microsoft.Data.SqlClient.Tests
11+
{
12+
public class SqlHelperTest
13+
{
14+
private void TimeOutATask()
15+
{
16+
TaskCompletionSource<bool> tcs = new TaskCompletionSource<bool>();
17+
AsyncHelper.WaitForCompletion(tcs.Task, 1); //Will time out as task uncompleted
18+
tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error
19+
}
20+
21+
private Exception UnwrapException(Exception e)
22+
{
23+
return e?.InnerException != null ? UnwrapException(e.InnerException) : e;
24+
}
25+
26+
[Fact]
27+
public void WaitForCompletion_DoesNotCreateUnobservedException()
28+
{
29+
var unobservedExceptionHappenedEvent = new AutoResetEvent(false);
30+
Exception unhandledException = null;
31+
void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a)
32+
{ unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); }
33+
34+
TaskScheduler.UnobservedTaskException += handleUnobservedException;
35+
36+
try
37+
{
38+
TimeOutATask(); //Create the task in another function so the task has no reference remaining
39+
GC.Collect(); //Force collection of unobserved task
40+
GC.WaitForPendingFinalizers();
41+
42+
bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1);
43+
if (unobservedExceptionHappend) //Save doing string interpolation in the happy case
44+
{
45+
var e = UnwrapException(unhandledException);
46+
Assert.False(true, $"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\"");
47+
}
48+
}
49+
finally
50+
{
51+
TaskScheduler.UnobservedTaskException -= handleUnobservedException;
52+
}
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)