Skip to content

Commit f540ada

Browse files
committed
add [StrictBind(...)]
1 parent 7566e7d commit f540ada

File tree

15 files changed

+1001
-89
lines changed

15 files changed

+1001
-89
lines changed

docs/rules/DAP049.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# DAP049
2+
3+
When using `[StrictBind(...)]`, the elements should be the member names on the corresponding type. This error simply means that Dapper
4+
could not find a member you specified. You can skip unwanted columns by passing `null` or `""`.

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public static readonly DiagnosticDescriptor
5656
AmbiguousProperties = LibraryWarning("DAP046", "Ambiguous properties", "Properties have same name '{0}' after normalization and can be conflated"),
5757
AmbiguousFields = LibraryWarning("DAP047", "Ambiguous fields", "Fields have same name '{0}' after normalization and can be conflated"),
5858
MoveFromDbString = LibraryWarning("DAP048", "Move from DbString to DbValue", "DbString achieves the same as [DbValue] does. Use it instead."),
59+
BoundMemberNotFound = LibraryError("DAP049", "Bound member not found", "The bound member '{0}' was not found"),
5960

6061
// SQL parse specific
6162
GeneralSqlError = SqlWarning("DAP200", "SQL error", "SQL error: {0}"),
@@ -103,7 +104,6 @@ public static readonly DiagnosticDescriptor
103104
ConcatenatedStringSqlExpression = SqlWarning("DAP242", "Concatenated string usage", "Data values should not be concatenated into SQL string - use parameters instead"),
104105
InvalidDatepartToken = SqlWarning("DAP243", "Valid datepart token expected", "Date functions require a recognized datepart argument"),
105106
SelectAggregateMismatch = SqlWarning("DAP244", "SELECT aggregate mismatch", "SELECT has mixture of aggregate and non-aggregate expressions"),
106-
PseudoPositionalParameter = SqlError("DAP245", "Avoid SQL pseudo-positional parameter", "It is more like Dapper will incorrectly treat this literal as a pseudo-positional parameter")
107-
;
107+
PseudoPositionalParameter = SqlError("DAP245", "Avoid SQL pseudo-positional parameter", "It is more like Dapper will incorrectly treat this literal as a pseudo-positional parameter");
108108
}
109109
}

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,12 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
665665
}
666666
}
667667

668+
if (flags.HasAll(OperationFlags.BindResultsByName) && GetClosestDapperAttribute(ctx, op, Types.StrictBindAttribute) is not null)
669+
{
670+
flags |= OperationFlags.StrictBind;
671+
}
672+
673+
668674
if (exitFirstFailure && flags.HasAny(OperationFlags.DoNotGenerate))
669675
{
670676
resultType = null;
@@ -775,6 +781,7 @@ enum ParameterMode
775781
}
776782
}
777783

784+
ImmutableArray<string> strictBind = default;
778785
int? batchSize = null;
779786
foreach (var attrib in methodAttribs)
780787
{
@@ -813,6 +820,9 @@ enum ParameterMode
813820
batchSize = batchTmp;
814821
}
815822
break;
823+
case Types.StrictBindAttribute:
824+
strictBind = ParseStrictBindColumns(attrib);
825+
break;
816826
}
817827
}
818828
}
@@ -841,8 +851,8 @@ enum ParameterMode
841851
}
842852

843853

844-
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null
845-
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps);
854+
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null && strictBind.IsDefault
855+
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps, strictBind);
846856
}
847857

848858
static void ValidateParameters(MemberMap? parameters, OperationFlags flags, Action<Diagnostic> onDiagnostic)

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ static void WriteSingleImplementation(
9292
break;
9393
}
9494
}
95-
sb.AppendReader(resultType, readers);
95+
sb.AppendReader(resultType, readers, additionalCommandState?.StrictBind ?? default);
9696
}
9797
else if (flags.HasAny(OperationFlags.Execute))
9898
{

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs

Lines changed: 141 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ internal void Generate(in GenerateState ctx)
391391

392392
if (flags.HasAny(OperationFlags.GetRowParser))
393393
{
394-
WriteGetRowParser(sb, resultType, readers);
394+
WriteGetRowParser(sb, resultType, readers, grp.Key.AdditionalCommandState?.StrictBind ?? default);
395395
}
396396
else if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql, additionalCommandState))
397397
{
@@ -449,9 +449,9 @@ internal void Generate(in GenerateState ctx)
449449
}
450450
sb.NewLine();
451451

452-
foreach (var pair in readers)
452+
foreach (var tuple in readers)
453453
{
454-
WriteRowFactory(ctx, sb, pair.Type, pair.Index);
454+
WriteRowFactory(ctx, sb, tuple.Type, tuple.Index, tuple.StrictBind, null /* TODO */);
455455
}
456456

457457
foreach (var tuple in factories)
@@ -468,9 +468,9 @@ internal void Generate(in GenerateState ctx)
468468
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.InterceptorsGenerated, null, callSiteCount, ctx.Nodes.Length, methodIndex, factories.Count(), readers.Count()));
469469
}
470470

471-
private static void WriteGetRowParser(CodeWriter sb, ITypeSymbol? resultType, in RowReaderState readers)
471+
private static void WriteGetRowParser(CodeWriter sb, ITypeSymbol? resultType, in RowReaderState readers, ImmutableArray<string> strictBind)
472472
{
473-
sb.Append("return ").AppendReader(resultType, readers)
473+
sb.Append("return ").AppendReader(resultType, readers, strictBind)
474474
.Append(".GetRowParser(reader, startIndex, length, returnNullIfFirstMissing);").NewLine();
475475
}
476476

@@ -732,7 +732,7 @@ static bool IsReserved(string name)
732732
}
733733
}
734734

735-
private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index)
735+
private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index, ImmutableArray<string> strictBind, Location? location)
736736
{
737737
var map = MemberMap.CreateForResults(type);
738738
if (map is null) return;
@@ -761,7 +761,7 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy
761761
WriteRowFactoryHeader();
762762

763763
WriteTokenizeMethod();
764-
WriteReadMethod();
764+
WriteReadMethod(context);
765765

766766
WriteRowFactoryFooter();
767767

@@ -780,31 +780,57 @@ void WriteRowFactoryFooter()
780780
void WriteTokenizeMethod()
781781
{
782782
sb.Append("public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)").Indent().NewLine();
783-
sb.Append("for (int i = 0; i < tokens.Length; i++)").Indent().NewLine()
784-
.Append("int token = -1;").NewLine()
785-
.Append("var name = reader.GetName(columnOffset);").NewLine()
786-
.Append("var type = reader.GetFieldType(columnOffset);").NewLine()
787-
.Append("switch (NormalizedHash(name))").Indent().NewLine();
783+
if (strictBind.IsDefault) // don't emit any tokens for strict binding
784+
{
785+
sb.Append("for (int i = 0; i < tokens.Length; i++)").Indent().NewLine()
786+
.Append("int token = -1;").NewLine()
787+
.Append("var name = reader.GetName(columnOffset);").NewLine()
788+
.Append("var type = reader.GetFieldType(columnOffset);").NewLine()
789+
.Append("switch (NormalizedHash(name))").Indent().NewLine();
788790

789-
int token = 0;
790-
foreach (var member in members)
791-
{
792-
var dbName = member.DbName;
793-
sb.Append("case ").Append(StringHashing.NormalizedHash(dbName))
794-
.Append(" when NormalizedEquals(name, ")
795-
.AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine()
796-
.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(token)
797-
.Append(" : ").Append(token + map.Members.Length).Append(";")
798-
.Append(token == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine()
799-
.Append("break;").Outdent(false).NewLine();
800-
token++;
791+
int token = 0;
792+
foreach (var member in members)
793+
{
794+
var dbName = member.DbName;
795+
sb.Append("case ").Append(StringHashing.NormalizedHash(dbName))
796+
.Append(" when NormalizedEquals(name, ")
797+
.AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine()
798+
.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(token)
799+
.Append(" : ").Append(token + map.Members.Length).Append(";")
800+
.Append(token == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine()
801+
.Append("break;").Outdent(false).NewLine();
802+
token++;
803+
}
804+
sb.Outdent().NewLine()
805+
.Append("tokens[i] = token;").NewLine()
806+
.Append("columnOffset++;").NewLine()
807+
.Outdent().NewLine();
801808
}
802-
sb.Outdent().NewLine()
803-
.Append("tokens[i] = token;").NewLine()
804-
.Append("columnOffset++;").NewLine();
805-
sb.Outdent().NewLine().Append("return null;").Outdent().NewLine();
809+
else
810+
{
811+
sb.Append("// strict-bind: ");
812+
for (int i = 0; i < strictBind.Length; i++)
813+
{
814+
if (i != 0) sb.Append(", ");
815+
var name = strictBind[i];
816+
if (string.IsNullOrWhiteSpace(name))
817+
{
818+
sb.Append("(n/a)");
819+
}
820+
else if (CompiledRegex.SimpleName.IsMatch(name))
821+
{
822+
sb.Append(name);
823+
}
824+
else
825+
{
826+
sb.Append("'").Append(name).Append("'");
827+
}
828+
}
829+
sb.NewLine().Append("global::System.Diagnostics.Debug.Assert(tokens.Length == ").Append(strictBind.Length).Append(""", "Strict-bind column count mismatch");""").NewLine();
830+
}
831+
sb.Append("return null;").Outdent().NewLine();
806832
}
807-
void WriteReadMethod()
833+
void WriteReadMethod(in GenerateState context)
808834
{
809835
const string DeferredConstructionVariableName = "value";
810836

@@ -852,47 +878,63 @@ void WriteReadMethod()
852878
? type.WithNullableAnnotation(NullableAnnotation.None) : type).Append(" result = new();").NewLine();
853879
}
854880

855-
sb.Append("foreach (var token in tokens)").Indent().NewLine()
856-
.Append("switch (token)").Indent().NewLine();
881+
ImmutableArray<ElementMember> readMembers;
882+
if (strictBind.IsDefault)
883+
{
884+
readMembers = members; // try to parse everything
885+
sb.Append("foreach (var token in tokens)");
886+
}
887+
else
888+
{
889+
readMembers = MapStrictBind(context, members, strictBind, location);
890+
sb.Append("for (int token = 0; token < tokens.Length; token++) // strict-bind");
891+
}
892+
sb.Indent().NewLine().Append("switch (token)").Indent().NewLine();
857893

858894
token = 0;
859-
foreach (var member in members)
895+
foreach (var member in readMembers)
860896
{
861-
var memberType = member.CodeType;
897+
if (member.Member is not null) // exclude non-mapped bindings
898+
{
899+
var memberType = member.CodeType;
862900

863-
member.GetDbType(out var readerMethod);
864-
var nullCheck = Inspection.CouldBeNullable(memberType) ? $"reader.IsDBNull(columnOffset) ? ({CodeWriter.GetTypeName(memberType.WithNullableAnnotation(NullableAnnotation.Annotated))})null : " : "";
865-
sb.Append("case ").Append(token).Append(":").NewLine().Indent(false);
901+
member.GetDbType(out var readerMethod);
902+
var nullCheck = Inspection.CouldBeNullable(memberType) ? $"reader.IsDBNull(columnOffset) ? ({CodeWriter.GetTypeName(memberType.WithNullableAnnotation(NullableAnnotation.Annotated))})null : " : "";
903+
sb.Append("case ").Append(token).Append(":").NewLine().Indent(false);
866904

867-
// write `result.X = ` or `member0 = `
868-
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
869-
else sb.Append("result.").Append(member.CodeName);
870-
sb.Append(" = ");
905+
// write `result.X = ` or `member0 = `
906+
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
907+
else sb.Append("result.").Append(member.CodeName);
908+
sb.Append(" = ");
871909

872-
sb.Append(nullCheck);
873-
if (readerMethod is null)
874-
{
875-
sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);");
876-
}
877-
else
878-
{
879-
sb.Append("reader.").Append(readerMethod).Append("(columnOffset);");
880-
}
910+
sb.Append(nullCheck);
911+
if (readerMethod is null)
912+
{
913+
sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);");
914+
}
915+
else
916+
{
917+
sb.Append("reader.").Append(readerMethod).Append("(columnOffset);");
918+
}
881919

882920

883-
sb.NewLine().Append("break;").NewLine().Outdent(false)
884-
.Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false);
921+
sb.NewLine().Append("break;").NewLine().Outdent(false);
885922

886-
// write `result.X = ` or `member0 = `
887-
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
888-
else sb.Append("result.").Append(member.CodeName);
923+
if (strictBind.IsDefault) // type-forgiving version; only emitted when not using strict-bind
924+
{
925+
sb.Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false);
889926

890-
sb.Append(" = ")
891-
.Append(nullCheck)
892-
.Append("GetValue<")
893-
.Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine()
894-
.Append("break;").NewLine().Outdent(false);
927+
// write `result.X = ` or `member0 = `
928+
if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token);
929+
else sb.Append("result.").Append(member.CodeName);
895930

931+
sb.Append(" = ")
932+
.Append(nullCheck)
933+
.Append("GetValue<")
934+
.Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine()
935+
.Append("break;").NewLine().Outdent(false);
936+
}
937+
}
896938
token++;
897939
}
898940

@@ -978,6 +1020,46 @@ void WriteDeferredMethodArgs()
9781020
}
9791021
}
9801022

1023+
private static ImmutableArray<ElementMember> MapStrictBind(in GenerateState state, ImmutableArray<ElementMember> members, ImmutableArray<string> strictBind, Location? location)
1024+
{
1025+
if (strictBind.IsDefault) return members; // not bound
1026+
1027+
var result = ImmutableArray.CreateBuilder<ElementMember>(strictBind.Length);
1028+
foreach (var seek in strictBind)
1029+
{
1030+
ElementMember found = default;
1031+
if (!string.IsNullOrWhiteSpace(seek))
1032+
{
1033+
foreach (var member in members)
1034+
{
1035+
if (member.CodeName == seek)
1036+
{
1037+
found = member;
1038+
break;
1039+
}
1040+
}
1041+
if (found.Member is null)
1042+
{
1043+
var normalizedSeek = StringHashing.Normalize(seek);
1044+
foreach (var member in members)
1045+
{
1046+
if (StringHashing.NormalizedEquals(member.CodeName, normalizedSeek))
1047+
{
1048+
found = member;
1049+
break;
1050+
}
1051+
}
1052+
}
1053+
if (found.Member is null)
1054+
{
1055+
state.ReportDiagnostic(Diagnostic.Create(DapperAnalyzer.Diagnostics.BoundMemberNotFound, location, seek));
1056+
}
1057+
}
1058+
result.Add(found);
1059+
}
1060+
return result.ToImmutable();
1061+
}
1062+
9811063
[Flags]
9821064
enum WriteArgsFlags
9831065
{

src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ internal sealed class AdditionalCommandState : IEquatable<AdditionalCommandState
3838
public readonly int? BatchSize;
3939
public readonly string? RowCountHintMemberName;
4040
public readonly ImmutableArray<CommandProperty> CommandProperties;
41+
public readonly ImmutableArray<string> StrictBind;
4142

4243
public bool HasRowCountHint => RowCountHint > 0 || RowCountHintMemberName is not null;
4344

@@ -74,7 +75,8 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited,
7475
}
7576

7677
return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize,
77-
Concat(inherited.CommandProperties, overrides.CommandProperties));
78+
Concat(inherited.CommandProperties, overrides.CommandProperties),
79+
overrides.StrictBind.IsDefault ? inherited.StrictBind : overrides.StrictBind);
7880
}
7981

8082
static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x, ImmutableArray<CommandProperty> y)
@@ -89,12 +91,13 @@ static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x,
8991

9092
internal AdditionalCommandState(
9193
int rowCountHint, string? rowCountHintMemberName, int? batchSize,
92-
ImmutableArray<CommandProperty> commandProperties)
94+
ImmutableArray<CommandProperty> commandProperties, ImmutableArray<string> strictBind)
9395
{
9496
RowCountHint = rowCountHint;
9597
RowCountHintMemberName = rowCountHintMemberName;
9698
BatchSize = batchSize;
9799
CommandProperties = commandProperties;
100+
StrictBind = strictBind;
98101
}
99102

100103

@@ -106,7 +109,8 @@ public bool Equals(in AdditionalCommandState other)
106109
=> RowCountHint == other.RowCountHint
107110
&& BatchSize == other.BatchSize
108111
&& RowCountHintMemberName == other.RowCountHintMemberName
109-
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties));
112+
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties))
113+
&& StrictBind.Equals(other.StrictBind);
110114

111115
private static bool Equals(in ImmutableArray<CommandProperty> x, in ImmutableArray<CommandProperty> y)
112116
{
@@ -145,5 +149,6 @@ static int GetHashCode(in ImmutableArray<CommandProperty> x)
145149
public override int GetHashCode()
146150
=> (RowCountHint + BatchSize.GetValueOrDefault()
147151
+ (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
148-
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties));
152+
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties))
153+
^ StrictBind.GetHashCode();
149154
}

0 commit comments

Comments
 (0)