Skip to content

Commit 811e23c

Browse files
authored
Merge pull request #107 from ZvonimirMatic/master
Source generation optimization
2 parents ffbcc8c + 6a31597 commit 811e23c

File tree

3 files changed

+173
-136
lines changed

3 files changed

+173
-136
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Microsoft.CodeAnalysis;
2+
using Microsoft.CodeAnalysis.CSharp.Syntax;
3+
4+
namespace EntityFrameworkCore.Projectables.Generator;
5+
6+
public class MemberDeclarationSyntaxAndCompilationEqualityComparer : IEqualityComparer<(MemberDeclarationSyntax, Compilation)>
7+
{
8+
public bool Equals((MemberDeclarationSyntax, Compilation) x, (MemberDeclarationSyntax, Compilation) y)
9+
{
10+
return GetMemberDeclarationSyntaxAndCompilationName(x.Item1, x.Item2) == GetMemberDeclarationSyntaxAndCompilationName(y.Item1, y.Item2);
11+
}
12+
13+
public int GetHashCode((MemberDeclarationSyntax, Compilation) obj)
14+
{
15+
return GetMemberDeclarationSyntaxAndCompilationName(obj.Item1, obj.Item2).GetHashCode();
16+
}
17+
18+
public static string GetMemberDeclarationSyntaxAndCompilationName(MemberDeclarationSyntax memberDeclarationSyntax, Compilation compilation)
19+
{
20+
return $"{compilation.AssemblyName}:{MemberDeclarationSyntaxEqualityComparer.GetMemberDeclarationSyntaxName(memberDeclarationSyntax)}";
21+
}
22+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using System.Text;
2+
using Microsoft.CodeAnalysis.CSharp.Syntax;
3+
4+
namespace EntityFrameworkCore.Projectables.Generator;
5+
6+
public class MemberDeclarationSyntaxEqualityComparer : IEqualityComparer<MemberDeclarationSyntax>
7+
{
8+
public bool Equals(MemberDeclarationSyntax x, MemberDeclarationSyntax y)
9+
{
10+
return GetMemberDeclarationSyntaxName(x) == GetMemberDeclarationSyntaxName(y);
11+
}
12+
13+
public int GetHashCode(MemberDeclarationSyntax obj)
14+
{
15+
return GetMemberDeclarationSyntaxName(obj).GetHashCode();
16+
}
17+
18+
public static string GetMemberDeclarationSyntaxName(MemberDeclarationSyntax memberDeclaration)
19+
{
20+
var sb = new StringBuilder();
21+
22+
// Get the member name
23+
if (memberDeclaration is MethodDeclarationSyntax methodDeclaration)
24+
{
25+
sb.Append(methodDeclaration.Identifier.Text);
26+
}
27+
else if (memberDeclaration is PropertyDeclarationSyntax propertyDeclaration)
28+
{
29+
sb.Append(propertyDeclaration.Identifier.Text);
30+
}
31+
else if (memberDeclaration is FieldDeclarationSyntax fieldDeclaration)
32+
{
33+
sb.Append(string.Join(", ", fieldDeclaration.Declaration.Variables.Select(v => v.Identifier.Text)));
34+
}
35+
36+
// Traverse up the tree to get containing type names
37+
var parent = memberDeclaration.Parent;
38+
while (parent != null)
39+
{
40+
switch (parent)
41+
{
42+
case NamespaceDeclarationSyntax namespaceDeclaration:
43+
sb.Insert(0, namespaceDeclaration.Name + ".");
44+
break;
45+
case ClassDeclarationSyntax classDeclaration:
46+
sb.Insert(0, classDeclaration.Identifier.Text + ".");
47+
break;
48+
case StructDeclarationSyntax structDeclaration:
49+
sb.Insert(0, structDeclaration.Identifier.Text + ".");
50+
break;
51+
case InterfaceDeclarationSyntax interfaceDeclaration:
52+
sb.Insert(0, interfaceDeclaration.Identifier.Text + ".");
53+
break;
54+
case EnumDeclarationSyntax enumDeclaration:
55+
sb.Insert(0, enumDeclaration.Identifier.Text + ".");
56+
break;
57+
}
58+
parent = parent.Parent;
59+
}
60+
61+
return sb.ToString();
62+
}
63+
}

src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs

Lines changed: 88 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,7 @@
33
using Microsoft.CodeAnalysis.CSharp;
44
using Microsoft.CodeAnalysis.CSharp.Syntax;
55
using Microsoft.CodeAnalysis.Text;
6-
using System;
7-
using System.Collections.Generic;
8-
using System.Collections.Immutable;
9-
using System.Diagnostics;
10-
using System.Linq;
11-
using System.Security.Cryptography.X509Certificates;
126
using System.Text;
13-
using System.Threading;
14-
using System.Threading.Tasks;
157
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
168

179
namespace EntityFrameworkCore.Projectables.Generator
@@ -41,167 +33,127 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
4133
{
4234
// Do a simple filter for members
4335
IncrementalValuesProvider<MemberDeclarationSyntax> memberDeclarations = context.SyntaxProvider
44-
.CreateSyntaxProvider(
45-
predicate: static (s, _) => s is MemberDeclarationSyntax m && m.AttributeLists.Count > 0,
46-
transform: static (c, _) => GetSemanticTargetForGeneration(c))
47-
.Where(static m => m is not null)!; // filter out attributed enums that we don't care about
36+
.ForAttributeWithMetadataName(
37+
ProjectablesAttributeName,
38+
predicate: static (s, _) => s is MemberDeclarationSyntax,
39+
transform: static (c, _) => (MemberDeclarationSyntax)c.TargetNode)
40+
.WithComparer(new MemberDeclarationSyntaxEqualityComparer());
4841

4942
// Combine the selected enums with the `Compilation`
50-
IncrementalValueProvider<(Compilation, ImmutableArray<MemberDeclarationSyntax>)> compilationAndEnums
51-
= context.CompilationProvider.Combine(memberDeclarations.Collect());
43+
IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations
44+
.Combine(context.CompilationProvider)
45+
.WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer());
5246

5347
// Generate the source using the compilation and enums
54-
context.RegisterImplementationSourceOutput(compilationAndEnums,
48+
context.RegisterImplementationSourceOutput(compilationAndMemberPairs,
5549
static (spc, source) => Execute(source.Item1, source.Item2, spc));
5650
}
5751

58-
static MemberDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
52+
static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context)
5953
{
60-
// we know the node is a MemberDeclarationSyntax
61-
var memberDeclarationSyntax = (MemberDeclarationSyntax)context.Node;
54+
var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context);
6255

63-
// loop through all the attributes on the method
64-
foreach (var attributeListSyntax in memberDeclarationSyntax.AttributeLists)
56+
if (projectable is null)
6557
{
66-
foreach (var attributeSyntax in attributeListSyntax.Attributes)
67-
{
68-
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is not IMethodSymbol attributeSymbol)
69-
{
70-
// weird, we couldn't get the symbol, ignore it
71-
continue;
72-
}
73-
74-
var attributeContainingTypeSymbol = attributeSymbol.ContainingType;
75-
var fullName = attributeContainingTypeSymbol.ToDisplayString();
76-
77-
// Is the attribute the [Projcetable] attribute?
78-
if (fullName == ProjectablesAttributeName)
79-
{
80-
// return the enum
81-
return memberDeclarationSyntax;
82-
}
83-
}
84-
}
85-
86-
// we didn't find the attribute we were looking for
87-
return null;
88-
}
89-
90-
static void Execute(Compilation compilation, ImmutableArray<MemberDeclarationSyntax> members, SourceProductionContext context)
91-
{
92-
if (members.IsDefaultOrEmpty)
93-
{
94-
// nothing to do yet
9558
return;
9659
}
9760

98-
var projectables = members
99-
.Select(x => ProjectableInterpreter.GetDescriptor(compilation, x, context))
100-
.Where(x => x is not null)
101-
.Select(x => x!);
102-
103-
var resultBuilder = new StringBuilder();
104-
105-
foreach (var projectable in projectables)
61+
if (projectable.MemberName is null)
10662
{
107-
if (projectable.MemberName is null)
108-
{
109-
throw new InvalidOperationException("Expected a memberName here");
110-
}
63+
throw new InvalidOperationException("Expected a memberName here");
64+
}
11165

112-
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
113-
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";
66+
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
67+
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";
11468

115-
var classSyntax = ClassDeclaration(generatedClassName)
116-
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
117-
.WithTypeParameterList(projectable.ClassTypeParameterList)
118-
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
119-
.AddAttributeLists(
120-
AttributeList()
121-
.AddAttributes(_editorBrowsableAttribute)
122-
)
123-
.AddMembers(
124-
MethodDeclaration(
125-
GenericName(
126-
Identifier("global::System.Linq.Expressions.Expression"),
127-
TypeArgumentList(
128-
SingletonSeparatedList(
129-
(TypeSyntax)GenericName(
130-
Identifier("global::System.Func"),
131-
GetLambdaTypeArgumentListSyntax(projectable)
132-
)
69+
var classSyntax = ClassDeclaration(generatedClassName)
70+
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
71+
.WithTypeParameterList(projectable.ClassTypeParameterList)
72+
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
73+
.AddAttributeLists(
74+
AttributeList()
75+
.AddAttributes(_editorBrowsableAttribute)
76+
)
77+
.AddMembers(
78+
MethodDeclaration(
79+
GenericName(
80+
Identifier("global::System.Linq.Expressions.Expression"),
81+
TypeArgumentList(
82+
SingletonSeparatedList(
83+
(TypeSyntax)GenericName(
84+
Identifier("global::System.Func"),
85+
GetLambdaTypeArgumentListSyntax(projectable)
13386
)
13487
)
135-
),
136-
"Expression"
137-
)
138-
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
139-
.WithTypeParameterList(projectable.TypeParameterList)
140-
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
141-
.WithBody(
142-
Block(
143-
ReturnStatement(
144-
ParenthesizedLambdaExpression(
145-
projectable.ParametersList ?? ParameterList(),
146-
null,
147-
projectable.ExpressionBody
148-
)
88+
)
89+
),
90+
"Expression"
91+
)
92+
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
93+
.WithTypeParameterList(projectable.TypeParameterList)
94+
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
95+
.WithBody(
96+
Block(
97+
ReturnStatement(
98+
ParenthesizedLambdaExpression(
99+
projectable.ParametersList ?? ParameterList(),
100+
null,
101+
projectable.ExpressionBody
149102
)
150103
)
151-
)
152-
);
104+
)
105+
)
106+
);
153107

154108
#nullable disable
155109

156-
var compilationUnit = CompilationUnit();
110+
var compilationUnit = CompilationUnit();
157111

158-
foreach (var usingDirective in projectable.UsingDirectives)
159-
{
160-
compilationUnit = compilationUnit.AddUsings(usingDirective);
161-
}
112+
foreach (var usingDirective in projectable.UsingDirectives)
113+
{
114+
compilationUnit = compilationUnit.AddUsings(usingDirective);
115+
}
162116

163-
if (projectable.ClassNamespace is not null)
164-
{
165-
compilationUnit = compilationUnit.AddUsings(
166-
UsingDirective(
167-
ParseName(projectable.ClassNamespace)
168-
)
169-
);
170-
}
117+
if (projectable.ClassNamespace is not null)
118+
{
119+
compilationUnit = compilationUnit.AddUsings(
120+
UsingDirective(
121+
ParseName(projectable.ClassNamespace)
122+
)
123+
);
124+
}
171125

172-
compilationUnit = compilationUnit
173-
.AddMembers(
174-
NamespaceDeclaration(
175-
ParseName("EntityFrameworkCore.Projectables.Generated")
176-
).AddMembers(classSyntax)
126+
compilationUnit = compilationUnit
127+
.AddMembers(
128+
NamespaceDeclaration(
129+
ParseName("EntityFrameworkCore.Projectables.Generated")
130+
).AddMembers(classSyntax)
131+
)
132+
.WithLeadingTrivia(
133+
TriviaList(
134+
Comment("// <auto-generated/>"),
135+
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
177136
)
178-
.WithLeadingTrivia(
179-
TriviaList(
180-
Comment("// <auto-generated/>"),
181-
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
182-
)
183-
);
137+
);
184138

185139

186-
context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));
140+
context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));
187141

142+
static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
143+
{
144+
var lambdaTypeArguments = TypeArgumentList(
145+
SeparatedList(
146+
// TODO: Document where clause
147+
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
148+
)
149+
);
188150

189-
static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
151+
if (projectable.ReturnTypeName is not null)
190152
{
191-
var lambdaTypeArguments = TypeArgumentList(
192-
SeparatedList(
193-
// TODO: Document where clause
194-
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
195-
)
196-
);
197-
198-
if (projectable.ReturnTypeName is not null)
199-
{
200-
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
201-
}
202-
203-
return lambdaTypeArguments;
153+
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
204154
}
155+
156+
return lambdaTypeArguments;
205157
}
206158
}
207159
}

0 commit comments

Comments
 (0)