Skip to content

Commit 5bc694c

Browse files
committed
support command definitions
1 parent 4f8cd50 commit 5bc694c

File tree

11 files changed

+533
-44
lines changed

11 files changed

+533
-44
lines changed

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private void ValidateDapperMethod(in OperationAnalysisContext ctx, IOperation sq
192192
var parseState = new ParseState(ctx);
193193
bool aotEnabled = IsEnabled(in parseState, invoke, Types.DapperAotAttribute, out var aotAttribExists);
194194
if (!aotEnabled) flags |= OperationFlags.DoNotGenerate;
195-
var location = SharedParseArgsAndFlags(parseState, invoke, ref flags, out var sql, out var argExpression, onDiagnostic, out _, exitFirstFailure: false);
195+
var location = SharedParseArgsAndFlags(parseState, invoke, ref flags, out var sql, out var argExpression, onDiagnostic, out _, exitFirstFailure: false, out var viaCommandDefinition);
196196

197197
// report our AOT readiness
198198
if (aotEnabled)
@@ -410,7 +410,7 @@ private void ValidateSql(in OperationAnalysisContext ctx, IOperation sqlSource,
410410
if (caseSensitive) flags |= SqlParseInputFlags.CaseSensitive;
411411

412412
// can we get the SQL itself?
413-
if (!TryGetConstantValueWithSyntax(sqlSource, out string? sql, out var sqlSyntax, out var stringSyntaxKind))
413+
if (!TryGetStringConstantValueWithSyntax(sqlSource, out string? sql, out var sqlSyntax, out var stringSyntaxKind))
414414
{
415415
DiagnosticDescriptor? descriptor = stringSyntaxKind switch
416416
{
@@ -503,32 +503,70 @@ StringSyntaxKind.ConcatenatedString or StringSyntaxKind.FormatString
503503

504504
// we want a common understanding of the setup between the analyzer and generator
505505
internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationOperation op, ref OperationFlags flags, out string? sql,
506-
out IOperation? argExpression, Action<Diagnostic>? reportDiagnostic, out ITypeSymbol? resultType, bool exitFirstFailure)
506+
out IOperation? argExpression, Action<Diagnostic>? reportDiagnostic, out ITypeSymbol? resultType, bool exitFirstFailure,
507+
out bool viaCommandDefinition)
507508
{
508509
var callLocation = op.GetMemberLocation();
509510
argExpression = null;
510511
sql = null;
511512
bool? buffered = null;
513+
viaCommandDefinition = false;
512514

513-
// check the args
514-
foreach (var arg in op.Arguments)
515+
// default is invocation, so simply take arguments
516+
IEnumerable<IArgumentOperation> arguments = op.Arguments;
517+
518+
// invocation can be packed into a CommandDefinition
519+
if (op.Arguments is { Length: 2 })
515520
{
521+
if (op.Arguments[0].Parameter?.Name == "cnn"
522+
&& op.Arguments[1].Parameter?.Name == "command" && op.Arguments[1].Parameter?.Type.IsDapperType("CommandDefinition") == true)
523+
{
524+
viaCommandDefinition = true;
525+
526+
// by default buffered CommandDefinition constructor initializes `buffered` as true via CommandFlags
527+
// https://github.com/DapperLib/Dapper/blob/5c7143f2e3585d4708294a3b0530a134e18ace86/Dapper/CommandDefinition.cs#L85
528+
buffered = true;
529+
530+
// in-place creation of CommandDefinition like `Query<T>(new CommandDefinition(...))`
531+
if (op.Arguments[1].Value is IObjectCreationOperation { Arguments.IsDefaultOrEmpty: false } commandDefinitionCreation )
532+
{
533+
arguments = commandDefinitionCreation.Arguments;
534+
}
516535

536+
// ideally here we would want to parse other CommandDefinition cases (i.e. local variable).
537+
// but it is complicated, so we can simply rely on passing CommandDefinition's members to the underlying query API
538+
// ...
539+
}
540+
}
541+
542+
// check the args. Names of the parameters are handling Dapper method parameters + CommandDefinition members
543+
foreach (var arg in arguments)
544+
{
517545
switch (arg.Parameter?.Name)
518546
{
519547
case "sql":
520-
if (TryGetConstantValueWithSyntax(arg, out string? s, out _, out _))
548+
case "commandText":
549+
if (TryGetStringConstantValueWithSyntax(arg, out string? s, out _, out _))
521550
{
522551
sql = s;
523552
}
524553
break;
554+
case "flags":
555+
{
556+
if (TryGetEnumConstantValueWithSyntax(arg, out int? value))
557+
{
558+
buffered = (value & 1) != 0; // CommandFlags.Buffered = 1
559+
}
560+
}
561+
break;
525562
case "buffered":
526563
if (TryGetConstantValue(arg, out bool b))
527564
{
528565
buffered = b;
529566
}
530567
break;
531568
case "param":
569+
case "parameters":
532570
if (arg.Value is not IDefaultValueOperation)
533571
{
534572
var expr = arg.Value;
@@ -555,6 +593,7 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
555593
case "length":
556594
case "returnNullIfFirstMissing":
557595
case "concreteType" when arg.Value is IDefaultValueOperation || (arg.ConstantValue.HasValue && arg.ConstantValue.Value is null):
596+
case "cancellationToken":
558597
// nothing to do
559598
break;
560599
case "commandType":
@@ -583,6 +622,17 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
583622
}
584623
}
585624
break;
625+
case "command":
626+
{
627+
// case for CommandDefinition - we need to check that we detected it correctly before
628+
// and if we did; then don't drop errors - we could not parse SQL / other flags in complex CommandDefinition usages,
629+
// but we can optimistically pass CommandDefinition data to underlying query
630+
if (!viaCommandDefinition)
631+
{
632+
goto default;
633+
}
634+
}
635+
break;
586636
default:
587637
if (!flags.HasAny(OperationFlags.NotAotSupported | OperationFlags.DoNotGenerate))
588638
{

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,38 @@ static void WriteSingleImplementation(
1919
in CommandFactoryState factories,
2020
in RowReaderState readers,
2121
string? fixedSql,
22-
AdditionalCommandState? additionalCommandState)
22+
AdditionalCommandState? additionalCommandState,
23+
bool viaCommandDefinition)
2324
{
2425
sb.Append("return ");
2526
if (flags.HasAll(OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
2627
{
2728
sb.Append("global::Dapper.DapperAotExtensions.AsEnumerableAsync(").Indent(false).NewLine();
2829
}
2930
// (DbConnection connection, DbTransaction? transaction, string sql, TArgs args, CommandType commandType, int timeout, CommandFactory<TArgs>? commandFactory)
30-
sb.Append("global::Dapper.DapperAotExtensions.Command(cnn, ").Append(Forward(methodParameters, "transaction")).Append(", ");
31+
sb.Append("global::Dapper.DapperAotExtensions.Command(cnn, ");
32+
33+
if (viaCommandDefinition) sb.Append("command.Transaction, ");
34+
else sb.Append(Forward(methodParameters, "transaction")).Append(", ");
35+
3136
if (fixedSql is not null)
3237
{
3338
sb.AppendVerbatimLiteral(fixedSql).Append(", ");
3439
}
3540
else
3641
{
37-
sb.Append("sql, ");
42+
if (viaCommandDefinition) sb.Append("command.CommandText, ");
43+
else sb.Append("sql, ");
3844
}
45+
3946
if (commandTypeMode == 0)
40-
{ // not hard-coded
41-
if (HasParam(methodParameters, "command"))
47+
{
48+
if (viaCommandDefinition)
49+
{
50+
sb.Append("command.CommandType ?? default");
51+
}
52+
// not hard-coded
53+
else if (HasParam(methodParameters, "command"))
4254
{
4355
sb.Append("command.GetValueOrDefault()");
4456
}
@@ -49,9 +61,14 @@ static void WriteSingleImplementation(
4961
}
5062
else
5163
{
52-
sb.Append("global::System.Data.CommandType.").Append(commandTypeMode.ToString());
64+
if (viaCommandDefinition) sb.Append("command.CommandType ?? default");
65+
else sb.Append("global::System.Data.CommandType.").Append(commandTypeMode.ToString());
5366
}
54-
sb.Append(", ").Append(Forward(methodParameters, "commandTimeout")).Append(HasParam(methodParameters, "commandTimeout") ? ".GetValueOrDefault()" : "").Append(", ");
67+
sb.Append(", ");
68+
69+
if (viaCommandDefinition) sb.Append("command.CommandTimeout ?? default, ");
70+
else sb.Append(Forward(methodParameters, "commandTimeout")).Append(HasParam(methodParameters, "commandTimeout") ? ".GetValueOrDefault()" : "").Append(", ");
71+
5572
if (flags.HasAny(OperationFlags.HasParameters))
5673
{
5774
var index = factories.GetIndex(parameterType!, map, cache, additionalCommandState, out var subIndex);
@@ -79,7 +96,7 @@ static void WriteSingleImplementation(
7996
OperationFlags.Unbuffered => "Unbuffered",
8097
_ => ""
8198
}).Append(isAsync ? "Async" : "").Append("(");
82-
WriteTypedArg(sb, parameterType).Append(", ");
99+
WriteTypedArg(sb, parameterType, viaCommandDefinition).Append(", ");
83100
if (!flags.HasAny(OperationFlags.SingleRow))
84101
{
85102
switch (flags & (OperationFlags.Buffered | OperationFlags.Unbuffered))
@@ -107,7 +124,7 @@ static void WriteSingleImplementation(
107124
sb.Append("<").Append(resultType).Append(">");
108125
}
109126
sb.Append("(");
110-
WriteTypedArg(sb, parameterType);
127+
WriteTypedArg(sb, parameterType, viaCommandDefinition);
111128
}
112129
else
113130
{
@@ -124,9 +141,11 @@ static void WriteSingleImplementation(
124141
sb.Append(", rowCountHint: ((").Append(parameterType).Append(")param!).").Append(additionalCommandState.RowCountHintMemberName);
125142
}
126143
}
127-
if (isAsync && HasParam(methodParameters, "cancellationToken"))
144+
if (isAsync && (HasParam(methodParameters, "cancellationToken") || viaCommandDefinition))
128145
{
129-
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
146+
sb.Append(", cancellationToken: ");
147+
if (viaCommandDefinition) sb.Append("command.CancellationToken");
148+
else sb.Append(Forward(methodParameters, "cancellationToken"));
130149
}
131150
if (flags.HasAll(OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
132151
{
@@ -153,10 +172,17 @@ static void WriteSingleImplementation(
153172
}
154173
sb.Append(";").NewLine();
155174

156-
static CodeWriter WriteTypedArg(CodeWriter sb, ITypeSymbol? parameterType)
175+
static CodeWriter WriteTypedArg(CodeWriter sb, ITypeSymbol? parameterType, bool viaCommandDefinition)
157176
{
177+
if (viaCommandDefinition)
178+
{
179+
sb.Append("command.Parameters");
180+
return sb;
181+
}
182+
158183
if (parameterType is null || parameterType.IsAnonymousType)
159184
{
185+
160186
sb.Append("param");
161187
}
162188
else

0 commit comments

Comments
 (0)