Skip to content

Commit 2462bf5

Browse files
authored
Add UseSTASynchronizationContext to STATestMethod (#7192)
2 parents ebcfad4 + 5ff9fc1 commit 2462bf5

File tree

4 files changed

+127
-4
lines changed

4 files changed

+127
-4
lines changed

src/TestFramework/TestFramework/Attributes/TestMethod/STATestMethodAttribute.cs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ public STATestMethodAttribute(TestMethodAttribute testMethodAttribute)
2929
: base(testMethodAttribute.DeclaringFilePath, testMethodAttribute.DeclaringLineNumber ?? -1)
3030
=> _testMethodAttribute = testMethodAttribute;
3131

32+
/// <summary>
33+
/// Gets or sets a value indicating whether the attribute will set a <see cref="SynchronizationContext"/> that preserves the same
34+
/// STA thread for async continuations.
35+
/// The default is <see langword="false"/>.
36+
/// </summary>
37+
public bool UseSTASynchronizationContext { get; set; }
38+
3239
/// <summary>
3340
/// The core execution of STA test method, which happens on the STA thread.
3441
/// </summary>
@@ -38,18 +45,39 @@ protected virtual Task<TestResult[]> ExecuteCoreAsync(ITestMethod testMethod)
3845
=> _testMethodAttribute is null ? base.ExecuteAsync(testMethod) : _testMethodAttribute.ExecuteAsync(testMethod);
3946

4047
/// <inheritdoc />
41-
public override Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
48+
public override async Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
4249
{
50+
if (UseSTASynchronizationContext)
51+
{
52+
SynchronizationContext? originalContext = SynchronizationContext.Current;
53+
var syncContext = new SingleThreadedSTASynchronizationContext();
54+
try
55+
{
56+
SynchronizationContext.SetSynchronizationContext(syncContext);
57+
58+
// The yield ensures that we switch to the STA thread created by SingleThreadedSTASynchronizationContext.
59+
await Task.Yield();
60+
TestResult[] testResults = await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
61+
return testResults;
62+
}
63+
finally
64+
{
65+
SynchronizationContext.SetSynchronizationContext(originalContext);
66+
syncContext.Complete();
67+
syncContext.Dispose();
68+
}
69+
}
70+
4371
if (Thread.CurrentThread.GetApartmentState() == ApartmentState.STA)
4472
{
45-
return ExecuteCoreAsync(testMethod);
73+
return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
4674
}
4775

4876
#if !NETFRAMEWORK
4977
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
5078
{
5179
// TODO: Throw?
52-
return ExecuteCoreAsync(testMethod);
80+
return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
5381
}
5482
#endif
5583

@@ -61,6 +89,6 @@ public override Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
6189
t.SetApartmentState(ApartmentState.STA);
6290
t.Start();
6391
t.Join();
64-
return Task.FromResult(results!);
92+
return results!;
6593
}
6694
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
namespace Microsoft.VisualStudio.TestTools.UnitTesting;
5+
6+
internal sealed class SingleThreadedSTASynchronizationContext : SynchronizationContext, IDisposable
7+
{
8+
private readonly BlockingCollection<Action> _queue = [];
9+
private readonly Thread _thread;
10+
11+
public SingleThreadedSTASynchronizationContext()
12+
{
13+
#if !NETFRAMEWORK
14+
if (!OperatingSystem.IsWindows())
15+
{
16+
throw new NotSupportedException("SingleThreadedSTASynchronizationContext is only supported on Windows.");
17+
}
18+
#endif
19+
20+
_thread = new Thread(() =>
21+
{
22+
SetSynchronizationContext(this);
23+
foreach (Action callback in _queue.GetConsumingEnumerable())
24+
{
25+
callback();
26+
}
27+
})
28+
{
29+
IsBackground = true,
30+
};
31+
_thread.SetApartmentState(ApartmentState.STA);
32+
_thread.Start();
33+
}
34+
35+
public override void Post(SendOrPostCallback d, object? state)
36+
=> _queue.Add(() => d(state));
37+
38+
public override void Send(SendOrPostCallback d, object? state)
39+
{
40+
if (Environment.CurrentManagedThreadId == _thread.ManagedThreadId)
41+
{
42+
d(state);
43+
}
44+
else
45+
{
46+
using var done = new ManualResetEventSlim();
47+
_queue.Add(() =>
48+
{
49+
try
50+
{
51+
d(state);
52+
}
53+
finally
54+
{
55+
done.Set();
56+
}
57+
});
58+
done.Wait();
59+
}
60+
}
61+
62+
public void Complete() => _queue.CompleteAdding();
63+
64+
public void Dispose() => _queue.Dispose();
65+
}

src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsExactInstance
4343
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsNotExactInstanceOfTypeInterpolatedStringHandler<TArg>
4444
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler
4545
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsNotExactInstanceOfTypeInterpolatedStringHandler
46+
Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.get -> bool
47+
Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.set -> void
4648
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.ContainsSingle(System.Func<object?, bool>! predicate, System.Collections.IEnumerable! collection, string? message = "", string! predicateExpression = "", string! collectionExpression = "") -> object?
4749
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, string? message = "", string! valueExpression = "") -> void
4850
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, ref Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler message, string! valueExpression = "") -> void
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
namespace MSTest.SelfRealExamples.UnitTests;
5+
6+
[TestClass]
7+
public class STATestMethodSyncContext
8+
{
9+
[STATestMethod]
10+
[OSCondition(OperatingSystems.Windows)]
11+
public void STAByDefaultDoesNotUseSynchronizationContext()
12+
{
13+
Assert.IsNull(SynchronizationContext.Current);
14+
Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());
15+
}
16+
17+
[STATestMethod(UseSTASynchronizationContext = true)]
18+
[OSCondition(OperatingSystems.Windows)]
19+
public async Task STAWithSynchronizationContextIsCorrect()
20+
{
21+
Assert.IsNotNull(SynchronizationContext.Current);
22+
Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());
23+
24+
await Task.Delay(100);
25+
26+
Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());
27+
}
28+
}

0 commit comments

Comments
 (0)