Skip to content

Commit b494b33

Browse files
authored
Generate parameters list using Roslyn API (#24)
1 parent 8c816bd commit b494b33

File tree

6 files changed

+65
-40
lines changed

6 files changed

+65
-40
lines changed

SqlMarshal.Tests/AsyncSqlConnectionTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ namespace Foo
249249
250250
partial class C
251251
{
252-
public partial async Task<Item?> M()
252+
public partial async System.Threading.Tasks.Task<Item?> M()
253253
{
254254
var connection = this.dbContext.Database.GetDbConnection();
255255
using var command = connection.CreateCommand();
@@ -311,7 +311,7 @@ namespace Foo
311311
312312
partial class C
313313
{
314-
public partial async Task<Foo.Item?> M()
314+
public partial async System.Threading.Tasks.Task<Foo.Item?> M()
315315
{
316316
var connection = this.connection;
317317
using var command = connection.CreateCommand();

SqlMarshal.Tests/CodeGenerationTestBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ protected string GetCSharpGeneratedOutput(string source, NullableContextOptions
5757

5858
protected string GetVisualBasicGeneratedOutput(string source)
5959
{
60-
var syntaxTree = VisualBasicSyntaxTree.ParseText(source);
60+
var syntaxTree = VisualBasicSyntaxTree.ParseText("Imports SqlMarshal.Annotations\r\n" + source);
6161

6262
var references = new List<MetadataReference>();
6363
Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies();

SqlMarshal/AbstractGenerator.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,11 @@ internal static IEnumerable<string> GetUsings(ClassGenerationContext classGenera
101101
}
102102

103103
/// <summary>
104-
/// Gets code for parameter declaration.
104+
/// Gets parameters list.
105105
/// </summary>
106-
/// <param name="methodSymbol">Method symbol from which we copy parameter.</param>
107-
/// <param name="parameter">Parameter to copy.</param>
108-
/// <param name="index">Index of parameter to copy.</param>
109-
/// <returns>Generated syntax node for the parameter.</returns>
110-
protected abstract SyntaxNode GetParameterDeclaration(IMethodSymbol methodSymbol, IParameterSymbol parameter, int index);
106+
/// <param name="methodSymbol">Method symbol from which we copy parameters.</param>
107+
/// <returns>Generated syntax node for the parameters list.</returns>
108+
protected abstract SyntaxNode GetParameters(IMethodSymbol methodSymbol);
111109

112110
private static string GetAccessibility(Accessibility a)
113111
{
@@ -917,7 +915,7 @@ private void ProcessMethod(
917915
var originalParameters = methodSymbol.Parameters;
918916

919917
bool hasCustomSql = methodGenerationContext.CustomSqlParameter != null;
920-
var signature = $"({string.Join(", ", originalParameters.Select((parameterSymbol, index) => this.GetParameterDeclaration(methodSymbol, parameterSymbol, index)))})";
918+
var signature = this.GetParameters(methodSymbol).ToString();
921919
var itemType = methodGenerationContext.ItemType;
922920
var getConnection = this.GetConnectionStatement(methodGenerationContext);
923921
var isList = methodGenerationContext.IsList || methodGenerationContext.IsEnumerable;
@@ -929,7 +927,8 @@ private void ProcessMethod(
929927
{
930928
if (methodSymbol.ReturnType.Name == "Task")
931929
{
932-
returnTypeName = "Task<" + returnType + "?>";
930+
var x = methodGenerationContext.ClassGenerationContext.CreateTaskType(returnType);
931+
returnTypeName = x.ToString();
933932
}
934933
else if (!methodGenerationContext.IsDataReader)
935934
{

SqlMarshal/CSharpGenerator.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,16 @@ public override void Initialize(GeneratorInitializationContext context)
7070
}
7171

7272
/// <inheritdoc/>
73-
protected override SyntaxNode GetParameterDeclaration(IMethodSymbol methodSymbol, IParameterSymbol parameter, int index)
73+
protected override SyntaxNode GetParameters(IMethodSymbol methodSymbol)
74+
{
75+
var parametersNodes = methodSymbol.Parameters.Select((parameterSymbol, index) => GetParameterDeclaration(methodSymbol, parameterSymbol, index));
76+
var separatedList = methodSymbol.Parameters.Length == 0
77+
? SeparatedList<ParameterSyntax>()
78+
: SeparatedList(parametersNodes, methodSymbol.Parameters.Take(methodSymbol.Parameters.Length - 1).Select(_ => Token(SyntaxKind.CommaToken).WithTrailingTrivia(Whitespace(" "))));
79+
return ParameterList(separatedList);
80+
}
81+
82+
private static ParameterSyntax GetParameterDeclaration(IMethodSymbol methodSymbol, IParameterSymbol parameter, int index)
7483
{
7584
var typeAsClause = ParseTypeName(parameter.Type.ToDisplayString()).WithTrailingTrivia(Whitespace(" "));
7685
if (parameter.RefKind == RefKind.Out)

SqlMarshal/ClassGenerationContext.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
namespace SqlMarshal;
88

99
using System.Collections.Generic;
10+
using System.Collections.Immutable;
1011
using System.Linq;
1112
using Microsoft.CodeAnalysis;
1213
using static SqlMarshal.Extensions;
@@ -59,6 +60,13 @@ public ClassGenerationContext(
5960

6061
public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => (_.IsList || _.IsEnumerable) && (IsScalarType(_.ItemType) || IsTuple(_.ItemType)));
6162

63+
public INamedTypeSymbol CreateTaskType(ITypeSymbol nestedType)
64+
{
65+
var taskType = this.GeneratorExecutionContext.Compilation.GetTypeByMetadataName($"System.Threading.Tasks.Task`1")!;
66+
var taskedType = taskType.Construct(ImmutableArray.Create(nestedType), ImmutableArray.Create(nestedType.NullableAnnotation == NullableAnnotation.None ? NullableAnnotation.Annotated : nestedType.NullableAnnotation));
67+
return taskedType;
68+
}
69+
6270
private static IFieldSymbol? GetConnectionField(INamedTypeSymbol classSymbol)
6371
{
6472
var fieldSymbols = classSymbol.GetMembers().OfType<IFieldSymbol>();

SqlMarshal/VisualBasicGenerator.cs

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,39 @@ public class VisualBasicGenerator : AbstractGenerator
2727
2828
Namespace SqlMarshal.Annotations
2929
30-
<System.AttributeUsage(System.AttributeTargets.Method, AllowMultiple:=True)>
31-
Friend NotInheritable Class SqlMarshalAttribute
32-
Inherits System.Attribute
30+
<System.AttributeUsage(System.AttributeTargets.Method, AllowMultiple:=True)>
31+
Friend NotInheritable Class SqlMarshalAttribute
32+
Inherits System.Attribute
3333
34-
Public Sub New()
35-
StoredProcedureName = """"
36-
End Sub
34+
Public Sub New()
35+
StoredProcedureName = """"
36+
End Sub
3737
38-
Public Sub New(name As String)
39-
StoredProcedureName = name
40-
End Sub
38+
Public Sub New(name As String)
39+
StoredProcedureName = name
40+
End Sub
4141
42-
Public Property StoredProcedureName As String
43-
End Class
42+
Public Property StoredProcedureName As String
43+
End Class
4444
45-
<System.AttributeUsage(System.AttributeTargets.Parameter, AllowMultiple:=False)>
46-
Friend NotInheritable Class RawSqlAttribute
47-
Inherits System.Attribute
45+
<System.AttributeUsage(System.AttributeTargets.Parameter, AllowMultiple:=False)>
46+
Friend NotInheritable Class RawSqlAttribute
47+
Inherits System.Attribute
4848
49-
Public Sub New()
50-
End Sub
51-
End Class
49+
Public Sub New()
50+
End Sub
51+
End Class
5252
53-
<System.AttributeUsage(System.AttributeTargets.Class, AllowMultiple:=False)>
54-
Friend NotInheritable Class RepositoryAttribute
55-
Inherits System.Attribute
53+
<System.AttributeUsage(System.AttributeTargets.Class, AllowMultiple:=False)>
54+
Friend NotInheritable Class RepositoryAttribute
55+
Inherits System.Attribute
5656
57-
Public Sub New(entityType As System.Type)
58-
EntityType = entityType
59-
End Sub
57+
Public Sub New(entityType As System.Type)
58+
EntityType = entityType
59+
End Sub
6060
61-
Public Property EntityType As System.Type
62-
End Class
61+
Public Property EntityType As System.Type
62+
End Class
6363
6464
End Namespace
6565
";
@@ -75,7 +75,16 @@ public override void Initialize(GeneratorInitializationContext context)
7575
}
7676

7777
/// <inheritdoc/>
78-
protected override SyntaxNode GetParameterDeclaration(IMethodSymbol methodSymbol, IParameterSymbol parameter, int index)
78+
protected override SyntaxNode GetParameters(IMethodSymbol methodSymbol)
79+
{
80+
var parametersNodes = methodSymbol.Parameters.Select((parameterSymbol, index) => GetParameterDeclaration(methodSymbol, parameterSymbol, index));
81+
var separatedList = methodSymbol.Parameters.Length == 0
82+
? SeparatedList<ParameterSyntax>()
83+
: SeparatedList(parametersNodes, methodSymbol.Parameters.Take(methodSymbol.Parameters.Length - 1).Select(_ => Token(SyntaxKind.CommaToken).WithTrailingTrivia(Whitespace(" "))));
84+
return ParameterList(separatedList);
85+
}
86+
87+
private static ParameterSyntax GetParameterDeclaration(IMethodSymbol methodSymbol, IParameterSymbol parameter, int index)
7988
{
8089
var typeAsClause = SimpleAsClause(ParseTypeName(parameter.Type.ToDisplayString()).WithLeadingTrivia(Whitespace(" ")));
8190
if (parameter.RefKind == RefKind.Out)
@@ -111,12 +120,12 @@ public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
111120
return;
112121
}
113122

114-
if (methodSymbol.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "SqlMarshalAttribute"))
123+
if (methodSymbol.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "SqlMarshal.Annotations.SqlMarshalAttribute"))
115124
{
116125
this.Methods.Add(methodSymbol);
117126
}
118127

119-
if (methodSymbol.ContainingType.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "RepositoryAttribute"))
128+
if (methodSymbol.ContainingType.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "SqlMarshal.Annotations.RepositoryAttribute"))
120129
{
121130
this.Methods.Add(methodSymbol);
122131
}

0 commit comments

Comments
 (0)