Skip to content

Commit 8fdc279

Browse files
authored
Make the MsTest source generator emit an override modifier when the base class has a TestContext property (#1440)
* Add and adapt tests for TestContext property defined in base class * Emit override modifier if base class has test context * Seal overridden TestContextProperty and call base if appropriate
1 parent 72a19fc commit 8fdc279

8 files changed

+236
-17
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
Item1: UsesVerify.g.cs,
3+
Item2:
4+
//-----------------------------------------------------
5+
// This code was generated by a tool.
6+
//
7+
// Changes to this file may cause incorrect behavior
8+
// and will be lost when the code is regenerated.
9+
// <auto-generated />
10+
//-----------------------------------------------------
11+
12+
partial class Derived
13+
{
14+
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Verify.MSTest.SourceGenerator", "1.0.0.0")]
15+
public sealed override global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext
16+
{
17+
get => global::VerifyMSTest.Verifier.CurrentTestContext.Value!.TestContext;
18+
set => global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());
19+
}
20+
}
21+
22+
}

src/Verify.MSTest.SourceGenerator.Tests/InheritanceTests.HasAttributeOnDerivedClassAndPropertyManuallyDefinedInBase.verified.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
partial class Derived
1313
{
1414
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Verify.MSTest.SourceGenerator", "1.0.0.0")]
15-
public global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext
15+
public sealed override global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext
1616
{
17-
get => global::VerifyMSTest.Verifier.CurrentTestContext.Value!.TestContext;
18-
set => global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());
17+
get => base.TestContext;
18+
set
19+
{
20+
global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());
21+
base.TestContext = value;
22+
}
1923
}
2024
}
2125

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
Item1: UsesVerify.g.cs,
3+
Item2:
4+
//-----------------------------------------------------
5+
// This code was generated by a tool.
6+
//
7+
// Changes to this file may cause incorrect behavior
8+
// and will be lost when the code is regenerated.
9+
// <auto-generated />
10+
//-----------------------------------------------------
11+
12+
partial class Derived
13+
{
14+
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Verify.MSTest.SourceGenerator", "1.0.0.0")]
15+
public sealed override global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext
16+
{
17+
get => base.TestContext;
18+
set
19+
{
20+
global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());
21+
base.TestContext = value;
22+
}
23+
}
24+
}
25+
26+
}

src/Verify.MSTest.SourceGenerator.Tests/InheritanceTests.cs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,51 @@ public partial class Derived : Base
7676
}
7777
""";
7878

79-
return VerifyGenerator(TestDriver.Run(source), ["CS0108"]);
79+
return VerifyGenerator(TestDriver.Run(source), ["CS0506"]);
80+
}
81+
82+
[Fact]
83+
public Task HasAttributeOnDerivedClassAndVirtualPropertyManuallyDefinedInBase()
84+
{
85+
var source = """
86+
using Microsoft.VisualStudio.TestTools.UnitTesting;
87+
using VerifyMSTest;
88+
89+
public class Base
90+
{
91+
public virtual TestContext TestContext { get; set; }
92+
}
93+
94+
[TestClass]
95+
[UsesVerify]
96+
public partial class Derived : Base
97+
{
98+
}
99+
""";
100+
101+
return VerifyGenerator(TestDriver.Run(source));
102+
}
103+
104+
[Fact]
105+
public Task HasAttributeOnDerivedClassAndAbstractPropertyManuallyDefinedInBase()
106+
{
107+
var source = """
108+
using Microsoft.VisualStudio.TestTools.UnitTesting;
109+
using VerifyMSTest;
110+
111+
public abstract class Base
112+
{
113+
public abstract TestContext TestContext { get; set; }
114+
}
115+
116+
[TestClass]
117+
[UsesVerify]
118+
public partial class Derived : Base
119+
{
120+
}
121+
""";
122+
123+
return VerifyGenerator(TestDriver.Run(source));
80124
}
81125

82126
[Fact]

src/Verify.MSTest.SourceGenerator.Tests/TestBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ private protected async Task VerifyGenerator(GeneratorDriverResults results, IRe
1212
var cached = results.CachedRun;
1313
Output.WriteLine($"Cached re-run of generators took: {cached.TimingInfo.ElapsedTime}");
1414

15-
if (expectedDiagnostics != null)
15+
if (expectedDiagnostics == null || expectedDiagnostics.Count == 0)
16+
{
17+
Assert.Empty(results.outputCompilation.GetDiagnostics());
18+
}
19+
else
1620
{
1721
foreach (var diagnostic in results.outputCompilation.GetDiagnostics())
1822
{

src/Verify.MSTest.SourceGenerator/ClassToGenerate.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,29 @@
88
/// The built in equality and hash code implementations won't work because this type includes a
99
/// collection (which has reference equality semantics), so we must implement them ourselves.
1010
/// </remarks>
11-
readonly record struct ClassToGenerate(string? Namespace, string ClassName, ParentClass[] ParentClasses)
11+
readonly record struct ClassToGenerate(
12+
string? Namespace,
13+
string ClassName,
14+
ClassToGenerate.PropertyFlags TestContextPropertyFlags,
15+
ParentClass[] ParentClasses)
1216
{
17+
[Flags]
18+
public enum PropertyFlags
19+
{
20+
None = 0b00,
21+
Override = 0b01,
22+
CallBase = 0b10,
23+
}
24+
1325
public string? Namespace { get; } = Namespace;
1426
public string ClassName { get; } = ClassName;
1527
public ParentClass[] ParentClasses { get; } = ParentClasses;
28+
public PropertyFlags TestContextPropertyFlags { get; } = TestContextPropertyFlags;
1629

1730
public bool Equals(ClassToGenerate other) =>
1831
Namespace == other.Namespace &&
1932
ClassName == other.ClassName &&
33+
TestContextPropertyFlags == other.TestContextPropertyFlags &&
2034
ParentClasses.SequenceEqual(other.ParentClasses);
2135

2236
public override int GetHashCode()
@@ -27,6 +41,7 @@ public override int GetHashCode()
2741
var hash = 1430287;
2842
hash = hash * 7302013 ^ (Namespace ?? string.Empty).GetHashCode();
2943
hash = hash * 7302013 ^ ClassName.GetHashCode();
44+
hash = hash * 7302013 ^ TestContextPropertyFlags.GetHashCode();
3045

3146
// Include (up to) the last 8 elements in the hash code to balance performance and specificity.
3247
// The runtime also does this for structural equality; see

src/Verify.MSTest.SourceGenerator/Emitter.cs

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using PropertyFlags = ClassToGenerate.PropertyFlags;
2+
13
class Emitter
24
{
35
const string AutoGenerationHeader = """
@@ -10,6 +12,9 @@ class Emitter
1012
//-----------------------------------------------------
1113
""";
1214

15+
const string SetVerifierTestContext =
16+
"global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());";
17+
1318
static readonly string GeneratedCodeAttribute =
1419
$"[global::System.CodeDom.Compiler.GeneratedCodeAttribute(\"{typeof(Emitter).Assembly.GetName().Name}\", \"{typeof(Emitter).Assembly.GetName().Version}\")]";
1520

@@ -52,20 +57,17 @@ void WriteParentTypes(ClassToGenerate classToGenerate)
5257
}
5358
}
5459

55-
void WriteClass(ClassToGenerate classToGenerate) =>
60+
void WriteClass(ClassToGenerate classToGenerate)
61+
{
5662
builder.Append("partial class ").AppendLine(classToGenerate.ClassName)
5763
.AppendLine("{")
58-
.IncreaseIndent()
59-
.AppendLine(GeneratedCodeAttribute)
60-
.AppendLine("public global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext")
61-
.AppendLine("{")
62-
.IncreaseIndent()
63-
.AppendLine("get => global::VerifyMSTest.Verifier.CurrentTestContext.Value!.TestContext;")
64-
.AppendLine("set => global::VerifyMSTest.Verifier.CurrentTestContext.Value = new global::VerifyMSTest.TestExecutionContext(value, GetType());")
65-
.DecreaseIndent()
66-
.AppendLine("}")
64+
.IncreaseIndent();
65+
AppendTestContextProperty(
66+
classToGenerate.TestContextPropertyFlags);
67+
builder
6768
.DecreaseIndent()
6869
.AppendLine("}");
70+
}
6971

7072
public string GenerateExtensionClasses(IReadOnlyCollection<ClassToGenerate> classesToGenerate, Cancel cancel)
7173
{
@@ -81,4 +83,55 @@ public string GenerateExtensionClasses(IReadOnlyCollection<ClassToGenerate> clas
8183

8284
return builder.ToString();
8385
}
86+
87+
public void AppendTestContextProperty(PropertyFlags flags)
88+
{
89+
builder.AppendLine(GeneratedCodeAttribute)
90+
.Append("public ")
91+
.Append(GetModifiers(flags))
92+
.AppendLine("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContext TestContext")
93+
.AppendLine("{")
94+
.IncreaseIndent()
95+
.Append($"get => ").AppendLine(GetterBody(flags));
96+
AppendSetter(flags);
97+
builder
98+
.DecreaseIndent()
99+
.AppendLine("}");
100+
}
101+
102+
void AppendSetter(PropertyFlags flags)
103+
{
104+
if (flags.HasFlag(PropertyFlags.CallBase))
105+
{
106+
AppendCallBaseSetter();
107+
}
108+
else
109+
{
110+
AppendDefaultSetter();
111+
}
112+
}
113+
114+
void AppendDefaultSetter() =>
115+
builder
116+
.Append("set => ").AppendLine(SetVerifierTestContext);
117+
118+
private void AppendCallBaseSetter() =>
119+
builder
120+
.AppendLine("set")
121+
.AppendLine("{")
122+
.IncreaseIndent()
123+
.AppendLine(SetVerifierTestContext)
124+
.AppendLine("base.TestContext = value;")
125+
.DecreaseIndent()
126+
.AppendLine("}");
127+
128+
static string GetterBody(PropertyFlags flags) =>
129+
flags.HasFlag(PropertyFlags.CallBase)
130+
? "base.TestContext;"
131+
: "global::VerifyMSTest.Verifier.CurrentTestContext.Value!.TestContext;";
132+
133+
static string GetModifiers(PropertyFlags flags) =>
134+
flags.HasFlag(PropertyFlags.Override)
135+
? "sealed override "
136+
: string.Empty;
84137
}

src/Verify.MSTest.SourceGenerator/Parser.cs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,63 @@
1+
using System.Data;
2+
using PropertyFlags = ClassToGenerate.PropertyFlags;
3+
14
static class Parser
25
{
36
public static ClassToGenerate? Parse(INamedTypeSymbol typeSymbol, TypeDeclarationSyntax typeSyntax, Cancel cancel)
47
{
58
var ns = typeSymbol.GetNamespaceOrDefault();
69
var name = typeSyntax.GetTypeNameWithGenericParameters();
710
var parents = GetParentClasses(typeSyntax, cancel);
11+
var testContextPropertyFlags =
12+
GetDerivedPropertyFlagsGiven(
13+
BaseTestContextProperties(typeSymbol).FirstOrDefault());
14+
15+
return new ClassToGenerate(
16+
Namespace: ns,
17+
ClassName: name,
18+
TestContextPropertyFlags: testContextPropertyFlags,
19+
ParentClasses: parents);
20+
}
21+
22+
private static PropertyFlags GetDerivedPropertyFlagsGiven(
23+
IPropertySymbol? baseClassProperty)
24+
{
25+
var isAbstract = baseClassProperty?.IsAbstract;
26+
switch (isAbstract)
27+
{
28+
case true:
29+
return PropertyFlags.Override;
30+
case false:
31+
return PropertyFlags.Override | PropertyFlags.CallBase;
32+
case null:
33+
return PropertyFlags.None;
34+
}
35+
}
36+
37+
static IEnumerable<IPropertySymbol> BaseTestContextProperties(
38+
INamedTypeSymbol typeSymbol)
39+
=>
40+
from baseClass in BaseClassesOf(typeSymbol)
41+
from testContextProperty in GetTestContextProperty(baseClass)
42+
select testContextProperty;
43+
44+
45+
static IEnumerable<IPropertySymbol> GetTestContextProperty(INamedTypeSymbol typeSymbol) =>
46+
typeSymbol
47+
.GetMembers()
48+
.OfType<IPropertySymbol>()
49+
.Where(property =>
50+
property.Name == "TestContext" &&
51+
property.DeclaredAccessibility == Accessibility.Public);
852

9-
return new ClassToGenerate(ns, name, parents);
53+
static IEnumerable<INamedTypeSymbol> BaseClassesOf(INamedTypeSymbol typeSymbol)
54+
{
55+
var baseType = typeSymbol.BaseType;
56+
while (baseType?.TypeKind == TypeKind.Class)
57+
{
58+
yield return baseType;
59+
baseType = baseType.BaseType;
60+
}
1061
}
1162

1263
static ParentClass[] GetParentClasses(TypeDeclarationSyntax typeSyntax, Cancel cancel)

0 commit comments

Comments
 (0)